diff --git a/src/gxhash.rs b/src/gxhash.rs index 901faa1..060d6eb 100644 --- a/src/gxhash.rs +++ b/src/gxhash.rs @@ -47,7 +47,7 @@ mod platform_defs { } #[inline] - pub unsafe fn mix(hash: state) -> state { + pub unsafe fn finalize(hash: state) -> u32 { let salt = vcombine_s64(vcreate_s64(4860325414534694371), vcreate_s64(8120763769363581797)); let keys = vmulq_s32( ReinterpretUnion { int64: salt }.int32, @@ -55,15 +55,8 @@ mod platform_defs { let a = vaeseq_u8(ReinterpretUnion { int8: hash }.uint8, vdupq_n_u8(0)); let b = vaesmcq_u8(a); let c = veorq_u8(b, ReinterpretUnion{ int32: keys }.uint8); - ReinterpretUnion{ uint8: c }.int8 - } - - #[inline] - pub unsafe fn fold(hash: state) -> u32 { - // Bit-cast the int8x16_t to uint32x4_t - let vec_u32: uint32x4_t = ReinterpretUnion { int8: hash }.uint32; - // Get the first u32 value from the vector - vgetq_lane_u32(vec_u32, 3) + let p = &ReinterpretUnion{ uint8: c }.int8 as *const state as *const u32; + *p } } @@ -93,8 +86,10 @@ mod platform_defs { #[inline] pub unsafe fn get_partial(p: *const state, len: isize) -> state { const MASK: [u8; size_of::() * 2] = [ - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ]; + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ]; // Safety check if check_same_page(p) { // false {// @@ -131,22 +126,26 @@ mod platform_defs { } #[inline] - pub unsafe fn mix(hash: state) -> state { - let salt = _mm256_set_epi64x(-4860325414534694371, 8120763769363581797, -4860325414534694371, 8120763769363581797); - //let keys = _mm256_mul_epu32(salt, hash); - _mm256_aesenc_epi128(hash, salt) - } - - #[inline] - pub unsafe fn fold(hash: state) -> u32 { - let p = &hash as *const state as *const u32; - *p ^ *p.offset(1) - ^ *p.offset(2) - ^ *p.offset(3) - ^ *p.offset(4) - ^ *p.offset(5) - ^ *p.offset(6) - ^ *p.offset(7) + #[allow(overflowing_literals)] + pub unsafe fn finalize(hash: state) -> u32 { + // Xor 256 state into 128 bit state for AES + let lower = _mm256_castsi256_si128(hash); + let upper = _mm256_extracti128_si256(hash, 1); + let mut hash = _mm_xor_si128(lower, upper); + + // Hardcoded AES keys + let salt1 = _mm_set_epi32(0x713B01D0, 0x8F2F35DB, 0xAF163956, 0x85459F85); + let salt2 = _mm_set_epi32(0x1DE09647, 0x92CFA39C, 0x3DD99ACA, 0xB89C054F); + let salt3 = _mm_set_epi32(0xC78B122B, 0x5544B1B7, 0x689D2B7D, 0xD0012E32); + + // 3 rounds of AES + hash = _mm_aesenc_si128(hash, salt1); + hash = _mm_aesenc_si128(hash, salt2); + hash = _mm_aesenclast_si128(hash, salt3); + + // Truncate to output hash size + let p = &hash as *const __m128i as *const u32; + *p } } @@ -157,8 +156,7 @@ pub use platform_defs::*; #[cfg(test)] pub static mut COUNTERS : Vec = vec![]; -#[inline] -//#[inline(never)] +#[inline] // To be disabled when profiling pub fn gxhash(input: &[u8]) -> u32 { unsafe { const VECTOR_SIZE: isize = std::mem::size_of::() as isize; @@ -241,7 +239,7 @@ pub fn gxhash(input: &[u8]) -> u32 { hash_vector = compress(hash_vector, partial_vector); } - fold(mix(hash_vector)) + finalize(hash_vector) } } @@ -271,10 +269,10 @@ mod tests { } #[test] - fn hash_of_zero_is_zero() { + fn hash_of_zero_is_not_zero() { let zero_bytes = [0u8; 1200]; let hash = gxhash(&zero_bytes); - assert_eq!(0, hash); + assert_ne!(0, hash); } } \ No newline at end of file