From 5818de7c724a8ac9c414328cd3d259ec978ff74a Mon Sep 17 00:00:00 2001 From: Tpt Date: Tue, 31 Jan 2023 18:27:41 +0100 Subject: [PATCH 1/4] Use SSE2 SIMD to accelerate IndexTable::find_entry --- src/index.rs | 68 +++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/src/index.rs b/src/index.rs index a6114572..9a00d5c2 100644 --- a/src/index.rs +++ b/src/index.rs @@ -11,6 +11,10 @@ use crate::{ table::{key::TableKey, SIZE_TIERS_BITS}, Key, }; +#[cfg(target_arch = "x86")] +use std::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; use std::convert::TryInto; // Index chunk consists of 8 64-bit entries. @@ -231,8 +235,70 @@ impl IndexTable { Ok(try_io!(Ok(&map[offset..offset + CHUNK_LEN]))) } - #[inline(never)] fn find_entry(&self, key_prefix: u64, sub_index: usize, chunk: &[u8]) -> (Entry, usize) { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if is_x86_feature_detected!("sse2") { + return self.find_entry_sse2(key_prefix, sub_index, chunk) + } + self.find_entry_regular(key_prefix, sub_index, chunk) + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + fn find_entry_sse2(&self, key_prefix: u64, sub_index: usize, chunk: &[u8]) -> (Entry, usize) { + assert!(chunk.len() >= CHUNK_ENTRIES * 8); // Bound checking (not done by SIMD instructions) + debug_assert!( + Entry::address_bits(self.id.index_bits()) <= 32, + "To be sure we can use all high 32 bits as key prefix" + ); + debug_assert_eq!( + CHUNK_ENTRIES % 4, + 0, + "We assume here we got buffer with a number of elements that is a multiple of 4" + ); + + unsafe { + let target = _mm_set1_epi32(((key_prefix << self.id.index_bits()) >> 32) as i32); + let mut i = (sub_index >> 2) << 2; // We keep an alignment of 4 + while i + 4 <= CHUNK_ENTRIES { + // We load the value 2 by 2 and move the high bits into the low part of the register + let first_two = _mm_shuffle_epi32::<0b10001101>(_mm_loadu_si128( + chunk[i * 8..].as_ptr() as *const __m128i, + )); + let last_two = _mm_shuffle_epi32::<0b10001101>(_mm_loadu_si128( + chunk[(i + 2) * 8..].as_ptr() as *const __m128i, + )); + // We set into current the input low parts in the interleaved order + let current = _mm_unpacklo_epi32(first_two, last_two); + let cmp = _mm_movemask_epi8(_mm_cmpeq_epi32(current, target)); + if cmp != 0 { + let position = i + if cmp & 0x000f != 0 { + 0 + } else if cmp & 0x00f0 != 0 { + 2 + } else if cmp & 0x0f00 != 0 { + 1 + } else if cmp & 0xf000 != 0 { + 3 + } else { + unreachable!() + }; + if position >= sub_index { + // We need to check we are not reading again the same input + return (Self::read_entry(chunk, position), position) + } + } + i += 4; + } + } + (Entry::empty(), 0) + } + + fn find_entry_regular( + &self, + key_prefix: u64, + sub_index: usize, + chunk: &[u8], + ) -> (Entry, usize) { assert!(chunk.len() >= CHUNK_ENTRIES * 8); let partial_key = Entry::extract_key(key_prefix, self.id.index_bits()); for i in sub_index..CHUNK_ENTRIES { From f8df837041ed0e10a92aa427c80b577ab7b9b172 Mon Sep 17 00:00:00 2001 From: Tpt Date: Thu, 2 Feb 2023 08:25:17 +0100 Subject: [PATCH 2/4] Use _mm_unpacklo_epi64 instead of _mm_unpacklo_epi32 --- src/index.rs | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/index.rs b/src/index.rs index 9a00d5c2..c4b3af37 100644 --- a/src/index.rs +++ b/src/index.rs @@ -267,21 +267,11 @@ impl IndexTable { let last_two = _mm_shuffle_epi32::<0b10001101>(_mm_loadu_si128( chunk[(i + 2) * 8..].as_ptr() as *const __m128i, )); - // We set into current the input low parts in the interleaved order - let current = _mm_unpacklo_epi32(first_two, last_two); + // We set into current the input low parts + let current = _mm_unpacklo_epi64(first_two, last_two); let cmp = _mm_movemask_epi8(_mm_cmpeq_epi32(current, target)); if cmp != 0 { - let position = i + if cmp & 0x000f != 0 { - 0 - } else if cmp & 0x00f0 != 0 { - 2 - } else if cmp & 0x0f00 != 0 { - 1 - } else if cmp & 0xf000 != 0 { - 3 - } else { - unreachable!() - }; + let position = i + (cmp.trailing_zeros() as usize) / 4; if position >= sub_index { // We need to check we are not reading again the same input return (Self::read_entry(chunk, position), position) From a976dd45958baa29a1acba0d15aa9487e0ff1f14 Mon Sep 17 00:00:00 2001 From: Tpt Date: Thu, 2 Feb 2023 08:34:02 +0100 Subject: [PATCH 3/4] Uses static target detection --- src/index.rs | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/src/index.rs b/src/index.rs index c4b3af37..0bf5f9c1 100644 --- a/src/index.rs +++ b/src/index.rs @@ -235,24 +235,15 @@ impl IndexTable { Ok(try_io!(Ok(&map[offset..offset + CHUNK_LEN]))) } + #[cfg(target_feature = "sse2")] fn find_entry(&self, key_prefix: u64, sub_index: usize, chunk: &[u8]) -> (Entry, usize) { - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("sse2") { - return self.find_entry_sse2(key_prefix, sub_index, chunk) - } - self.find_entry_regular(key_prefix, sub_index, chunk) - } - - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - fn find_entry_sse2(&self, key_prefix: u64, sub_index: usize, chunk: &[u8]) -> (Entry, usize) { assert!(chunk.len() >= CHUNK_ENTRIES * 8); // Bound checking (not done by SIMD instructions) debug_assert!( Entry::address_bits(self.id.index_bits()) <= 32, "To be sure we can use all high 32 bits as key prefix" ); - debug_assert_eq!( - CHUNK_ENTRIES % 4, - 0, + const _: () = assert!( + CHUNK_ENTRIES % 4 == 0, "We assume here we got buffer with a number of elements that is a multiple of 4" ); @@ -283,12 +274,8 @@ impl IndexTable { (Entry::empty(), 0) } - fn find_entry_regular( - &self, - key_prefix: u64, - sub_index: usize, - chunk: &[u8], - ) -> (Entry, usize) { + #[cfg(not(target_feature = "sse2"))] + fn find_entry(&self, key_prefix: u64, sub_index: usize, chunk: &[u8]) -> (Entry, usize) { assert!(chunk.len() >= CHUNK_ENTRIES * 8); let partial_key = Entry::extract_key(key_prefix, self.id.index_bits()); for i in sub_index..CHUNK_ENTRIES { From e5556048612a730ab08ae06716deb7a03861a78b Mon Sep 17 00:00:00 2001 From: Tpt Date: Thu, 2 Feb 2023 21:26:53 +0100 Subject: [PATCH 4/4] IndexTable::find_entry: support address bigger than 32bits --- src/index.rs | 71 +++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 14 deletions(-) diff --git a/src/index.rs b/src/index.rs index 0bf5f9c1..0c1aac2c 100644 --- a/src/index.rs +++ b/src/index.rs @@ -15,7 +15,7 @@ use crate::{ use std::arch::x86::*; #[cfg(target_arch = "x86_64")] use std::arch::x86_64::*; -use std::convert::TryInto; +use std::{cmp::max, convert::TryInto}; // Index chunk consists of 8 64-bit entries. const CHUNK_LEN: usize = CHUNK_ENTRIES * ENTRY_BYTES; // 512 bytes @@ -235,28 +235,38 @@ impl IndexTable { Ok(try_io!(Ok(&map[offset..offset + CHUNK_LEN]))) } - #[cfg(target_feature = "sse2")] fn find_entry(&self, key_prefix: u64, sub_index: usize, chunk: &[u8]) -> (Entry, usize) { + if cfg!(target_feature = "sse2") { + self.find_entry_sse2(key_prefix, sub_index, chunk) + } else { + self.find_entry_base(key_prefix, sub_index, chunk) + } + } + + #[cfg(target_feature = "sse2")] + fn find_entry_sse2(&self, key_prefix: u64, sub_index: usize, chunk: &[u8]) -> (Entry, usize) { assert!(chunk.len() >= CHUNK_ENTRIES * 8); // Bound checking (not done by SIMD instructions) - debug_assert!( - Entry::address_bits(self.id.index_bits()) <= 32, - "To be sure we can use all high 32 bits as key prefix" - ); const _: () = assert!( CHUNK_ENTRIES % 4 == 0, "We assume here we got buffer with a number of elements that is a multiple of 4" ); + let shift = max(32, Entry::address_bits(self.id.index_bits())); unsafe { - let target = _mm_set1_epi32(((key_prefix << self.id.index_bits()) >> 32) as i32); + let target = _mm_set1_epi32(((key_prefix << self.id.index_bits()) >> shift) as i32); + let shift_mask = _mm_set_epi64x(0, shift.into()); let mut i = (sub_index >> 2) << 2; // We keep an alignment of 4 while i + 4 <= CHUNK_ENTRIES { - // We load the value 2 by 2 and move the high bits into the low part of the register - let first_two = _mm_shuffle_epi32::<0b10001101>(_mm_loadu_si128( - chunk[i * 8..].as_ptr() as *const __m128i, + // We load the value 2 by 2 + // Then we remove the address by shifting such that the partial key is in the low + // part + let first_two = _mm_shuffle_epi32::<0b11011000>(_mm_srl_epi64( + _mm_loadu_si128(chunk[i * 8..].as_ptr() as *const __m128i), + shift_mask, )); - let last_two = _mm_shuffle_epi32::<0b10001101>(_mm_loadu_si128( - chunk[(i + 2) * 8..].as_ptr() as *const __m128i, + let last_two = _mm_shuffle_epi32::<0b11011000>(_mm_srl_epi64( + _mm_loadu_si128(chunk[(i + 2) * 8..].as_ptr() as *const __m128i), + shift_mask, )); // We set into current the input low parts let current = _mm_unpacklo_epi64(first_two, last_two); @@ -274,8 +284,7 @@ impl IndexTable { (Entry::empty(), 0) } - #[cfg(not(target_feature = "sse2"))] - fn find_entry(&self, key_prefix: u64, sub_index: usize, chunk: &[u8]) -> (Entry, usize) { + fn find_entry_base(&self, key_prefix: u64, sub_index: usize, chunk: &[u8]) -> (Entry, usize) { assert!(chunk.len() >= CHUNK_ENTRIES * 8); let partial_key = Entry::extract_key(key_prefix, self.id.index_bits()); for i in sub_index..CHUNK_ENTRIES { @@ -575,6 +584,7 @@ impl IndexTable { #[cfg(test)] mod test { use super::*; + use std::path::PathBuf; #[test] fn test_entries() { @@ -595,4 +605,37 @@ mod test { assert!(IndexTable::transmute_chunk(chunk2) == chunk); } + + #[test] + fn test_find_entries() { + let partial_keys = [1, 1 << 10, 1 << 20]; + for index_bits in [16, 18, 20, 22] { + let index_table = IndexTable { + id: TableId(index_bits.into()), + map: RwLock::new(None), + path: PathBuf::new(), + }; + + let data_address = Address::from_u64((1 << index_bits) - 1); + + let mut chunk = [0; CHUNK_ENTRIES * 8]; + for (i, partial_key) in partial_keys.iter().enumerate() { + chunk[i * 8..(i + 1) * 8].copy_from_slice( + &Entry::new(data_address, *partial_key, index_bits).as_u64().to_le_bytes(), + ); + } + + for partial_key in &partial_keys { + let key_prefix = *partial_key << (CHUNK_ENTRIES_BITS + SIZE_TIERS_BITS); + assert_eq!( + index_table.find_entry_sse2(key_prefix, 0, &chunk).0.partial_key(index_bits), + *partial_key + ); + assert_eq!( + index_table.find_entry_base(key_prefix, 0, &chunk).0.partial_key(index_bits), + *partial_key + ); + } + } + } }