Implement partial using cmp for ARM
ogxd committed Oct 20, 2023
1 parent 57dee0a commit 9ca1b1e
Showing 2 changed files with 40 additions and 23 deletions.
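The core of this change on the ARM path is how the trailing partial block is masked: rather than sliding over a precomputed 0xFF/0x00 byte table, the mask is generated by comparing a vector of lane indices against the remaining length (`vcgtq_s8`), mirroring the existing AVX2 path. Below is a minimal standalone sketch of that technique, assuming the `std::arch::aarch64` NEON intrinsics; the helper name is illustrative, and the full 16-byte load can overread past `len`, which is why the diff guards it with a same-page check.

```rust
#[cfg(target_arch = "aarch64")]
unsafe fn partial_load_cmp(p: *const i8, len: usize) -> std::arch::aarch64::int8x16_t {
    use std::arch::aarch64::*;
    // Lane indices 0..15.
    let indices = vld1q_s8([0i8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15].as_ptr());
    // Lanes with index < len become 0xFF, the rest 0x00.
    let mask = vcgtq_s8(vdupq_n_s8(len as i8), indices);
    // Full 16-byte load (may overread past `len`), then keep only the first `len` bytes.
    vandq_s8(vld1q_s8(p), vreinterpretq_s8_u8(mask))
}
```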
2 changes: 1 addition & 1 deletion README.md
@@ -1,4 +1,4 @@
# gxhash-rust
# gxhash
![CI](https://github.com/ogxd/gxhash-rust/actions/workflows/rust.yml/badge.svg)

To date, the fastest non-cryptographic hashing algorithm
61 changes: 39 additions & 22 deletions src/gxhash.rs
@@ -32,17 +32,37 @@ mod platform_defs {

#[inline]
pub unsafe fn get_partial(p: *const state, len: isize) -> state {
const MASK: [u8; size_of::<state>() * 2] = [
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ];

let mask = vld1q_s8((MASK.as_ptr() as *const i8).offset(size_of::<state>() as isize - len));
let mut vec = vandq_s8(load_unaligned(p), mask);
let partial_vector: state;
if check_same_page(p) {
// Unsafe (hence the check) but much faster
let indices = vld1q_s8([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15].as_ptr());
let mask = vcgtq_s8(vdupq_n_s8(len as i8), indices);
partial_vector = vandq_s8(load_unaligned(p), ReinterpretUnion { uint8: mask }.int8);
} else {
// Safer but slower, using memcpy
partial_vector = get_partial_safe(p as *const i8, len as usize);
}
// Prevents padded zeroes from introducing bias
return vaddq_s8(partial_vector, vdupq_n_s8(len as i8));
}

// To avoid collisions for zero right padded inputs, we mutate this vector using its used length
vec = vaddq_s8(vec, vdupq_n_s8(len as i8));
#[inline]
unsafe fn check_same_page(ptr: *const state) -> bool {
let address = ptr as usize;
// Mask to keep only the last 12 bits (the byte offset within a 4 KiB page)
let offset_within_page = address & 0xFFF;
// Check that reading size_of::<state>() bytes from this offset stays within the page
offset_within_page <= (4096 - size_of::<state>() - 1)
}

return vec;
#[inline]
unsafe fn get_partial_safe(data: *const i8, len: usize) -> state {
// Temporary buffer filled with zeros
let mut buffer = [0i8; size_of::<state>()];
// Copy data into the buffer
std::ptr::copy(data, buffer.as_mut_ptr(), len);
// Load the buffer into a state (int8x16_t) vector
vld1q_s8(buffer.as_ptr())
}

#[inline]
@@ -73,22 +93,19 @@ mod platform_defs {
}

#[inline]
pub unsafe fn finalize(hash: state) -> u32 {
pub unsafe fn finalize(hash: state, seed: i32) -> state {
// Hardcoded AES keys
let salt1 = vld1q_u32([0x713B01D0, 0x8F2F35DB, 0xAF163956, 0x85459F85].as_ptr());
let salt2 = vld1q_u32([0x1DE09647, 0x92CFA39C, 0x3DD99ACA, 0xB89C054F].as_ptr());
let salt3 = vld1q_u32([0xC78B122B, 0x5544B1B7, 0x689D2B7D, 0xD0012E32].as_ptr());
let keys_1 = vld1q_u32([0x713B01D0, 0x8F2F35DB, 0xAF163956, 0x85459F85].as_ptr());
let keys_2 = vld1q_u32([0x1DE09647, 0x92CFA39C, 0x3DD99ACA, 0xB89C054F].as_ptr());
let keys_3 = vld1q_u32([0xC78B122B, 0x5544B1B7, 0x689D2B7D, 0xD0012E32].as_ptr());

// 3 rounds of AES
let mut hash = ReinterpretUnion{ int8: hash }.uint8;
hash = aes_encrypt(hash, ReinterpretUnion{ uint32: salt1 }.uint8);
hash = aes_encrypt(hash, ReinterpretUnion{ uint32: salt2 }.uint8);
hash = aes_encrypt_last(hash, ReinterpretUnion{ uint32: salt3 }.uint8);
let hash = ReinterpretUnion{ uint8: hash }.int8;

// Truncate to output hash size
let p = &hash as *const state as *const u32;
*p
hash = aes_encrypt(hash, ReinterpretUnion{ int32: vdupq_n_s32(seed) }.uint8);
hash = aes_encrypt(hash, ReinterpretUnion{ uint32: keys_1 }.uint8);
hash = aes_encrypt(hash, ReinterpretUnion{ uint32: keys_2 }.uint8);
hash = aes_encrypt_last(hash, ReinterpretUnion{ uint32: keys_3 }.uint8);
return ReinterpretUnion{ uint8: hash }.int8;
}
}

@@ -130,7 +147,7 @@ mod platform_defs {
let mask = _mm256_cmpgt_epi8(_mm256_set1_epi8(len as i8), indices);
partial_vector = _mm256_and_si256(_mm256_loadu_si256(p), mask);
} else {
partial_vector = get_partial_safe(p as *const u8, len as usize)
partial_vector = get_partial_safe(p as *const u8, len as usize)
}
// Prevents padded zeroes from introducing bias
_mm256_add_epi32(partial_vector, _mm256_set1_epi32(len as i32))
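For reference, the same-page guard in `check_same_page` above boils down to: a full vector read starting at the pointer must not cross a 4 KiB page boundary. A small self-contained sketch of that arithmetic, assuming `state` is the 16-byte `int8x16_t` used on the ARM path (the function name and the example addresses are illustrative, not part of the crate):

```rust
const PAGE_SIZE: usize = 4096;
const STATE_SIZE: usize = 16; // size_of::<state>() on the ARM path

fn stays_in_page(address: usize) -> bool {
    // Offset of the read within its 4 KiB page.
    let offset_within_page = address & (PAGE_SIZE - 1);
    // Conservative bound, matching the `<= 4096 - size_of::<state>() - 1` check in the diff.
    offset_within_page <= PAGE_SIZE - STATE_SIZE - 1
}

fn main() {
    assert!(stays_in_page(0x2000));  // offset 0: the 16-byte read clearly fits
    assert!(stays_in_page(0x2FEF));  // offset 4079: still within the bound
    assert!(!stays_in_page(0x2FF0)); // offset 4080: rejected, falls back to the memcpy path
}
```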
