From b2434ff48cf103682e618bbb32e93e2b592cb635 Mon Sep 17 00:00:00 2001 From: usamoi Date: Fri, 30 Aug 2024 00:49:57 +0800 Subject: [PATCH] feat: fallback version of fast scan Signed-off-by: usamoi --- crates/quantization/src/fast_scan/b4.rs | 600 +++++++++++++----------- crates/quantization/src/product/mod.rs | 10 +- crates/quantization/src/scalar/mod.rs | 10 +- crates/rabitq/src/quant/quantizer.rs | 10 +- 4 files changed, 341 insertions(+), 289 deletions(-) diff --git a/crates/quantization/src/fast_scan/b4.rs b/crates/quantization/src/fast_scan/b4.rs index d90f1fd88..7af0620d3 100644 --- a/crates/quantization/src/fast_scan/b4.rs +++ b/crates/quantization/src/fast_scan/b4.rs @@ -58,319 +58,371 @@ pub fn pack(width: u32, r: [Vec; 32]) -> impl Iterator { }) } -pub fn is_supported() -> bool { +mod fast_scan_b4 { #[cfg(target_arch = "x86_64")] - { - if detect::v4::detect() { - return true; - } - if detect::v3::detect() { - return true; - } - if detect::v2::detect() { - return true; - } - } - false -} - -pub fn fast_scan(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { - #[cfg(target_arch = "x86_64")] - { - if detect::v4::detect() { - return unsafe { fast_scan_v4(width, codes, lut) }; - } - if detect::v3::detect() { - return unsafe { fast_scan_v3(width, codes, lut) }; - } - if detect::v2::detect() { - return unsafe { fast_scan_v2(width, codes, lut) }; - } - } - let _ = (width, codes, lut); - unimplemented!() -} - -#[cfg(target_arch = "x86_64")] -#[detect::target_cpu(enable = "v4")] -unsafe fn fast_scan_v4(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { - // bounds checking is not enforced by compiler, so check it manually - assert_eq!(codes.len(), width as usize * 16); - assert_eq!(lut.len(), width as usize * 16); - - unsafe { - use std::arch::x86_64::*; - - #[inline] - #[detect::target_cpu(enable = "v4")] - unsafe fn combine2x2(x0x1: __m256i, y0y1: __m256i) -> __m256i { - unsafe { - let x1y0 = _mm256_permute2f128_si256(x0x1, y0y1, 0x21); - let x0y1 = _mm256_blend_epi32(x0x1, y0y1, 0xf0); - _mm256_add_epi16(x1y0, x0y1) + #[detect::target_cpu(enable = "v4")] + unsafe fn fast_scan_b4_v4(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { + // bounds checking is not enforced by compiler, so check it manually + assert_eq!(codes.len(), width as usize * 16); + assert_eq!(lut.len(), width as usize * 16); + + unsafe { + use std::arch::x86_64::*; + + #[inline] + #[detect::target_cpu(enable = "v4")] + unsafe fn combine2x2(x0x1: __m256i, y0y1: __m256i) -> __m256i { + unsafe { + let x1y0 = _mm256_permute2f128_si256(x0x1, y0y1, 0x21); + let x0y1 = _mm256_blend_epi32(x0x1, y0y1, 0xf0); + _mm256_add_epi16(x1y0, x0y1) + } } - } - #[inline] - #[detect::target_cpu(enable = "v4")] - unsafe fn combine4x2(x0x1x2x3: __m512i, y0y1y2y3: __m512i) -> __m256i { - unsafe { - let x0x1 = _mm512_castsi512_si256(x0x1x2x3); - let x2x3 = _mm512_extracti64x4_epi64(x0x1x2x3, 1); - let y0y1 = _mm512_castsi512_si256(y0y1y2y3); - let y2y3 = _mm512_extracti64x4_epi64(y0y1y2y3, 1); - let x01y01 = combine2x2(x0x1, y0y1); - let x23y23 = combine2x2(x2x3, y2y3); - _mm256_add_epi16(x01y01, x23y23) + #[inline] + #[detect::target_cpu(enable = "v4")] + unsafe fn combine4x2(x0x1x2x3: __m512i, y0y1y2y3: __m512i) -> __m256i { + unsafe { + let x0x1 = _mm512_castsi512_si256(x0x1x2x3); + let x2x3 = _mm512_extracti64x4_epi64(x0x1x2x3, 1); + let y0y1 = _mm512_castsi512_si256(y0y1y2y3); + let y2y3 = _mm512_extracti64x4_epi64(y0y1y2y3, 1); + let x01y01 = combine2x2(x0x1, y0y1); + let x23y23 = combine2x2(x2x3, y2y3); + _mm256_add_epi16(x01y01, 
x23y23) + } } - } - let mut accu_0 = _mm512_setzero_si512(); - let mut accu_1 = _mm512_setzero_si512(); - let mut accu_2 = _mm512_setzero_si512(); - let mut accu_3 = _mm512_setzero_si512(); + let mut accu_0 = _mm512_setzero_si512(); + let mut accu_1 = _mm512_setzero_si512(); + let mut accu_2 = _mm512_setzero_si512(); + let mut accu_3 = _mm512_setzero_si512(); - let mut i = 0_usize; - while i + 4 <= width as usize { - let c = _mm512_loadu_si512(codes.as_ptr().add(i * 16).cast()); + let mut i = 0_usize; + while i + 4 <= width as usize { + let c = _mm512_loadu_si512(codes.as_ptr().add(i * 16).cast()); - let mask = _mm512_set1_epi8(0xf); - let clo = _mm512_and_si512(c, mask); - let chi = _mm512_and_si512(_mm512_srli_epi16(c, 4), mask); + let mask = _mm512_set1_epi8(0xf); + let clo = _mm512_and_si512(c, mask); + let chi = _mm512_and_si512(_mm512_srli_epi16(c, 4), mask); - let lut = _mm512_loadu_si512(lut.as_ptr().add(i * 16).cast()); - let res_lo = _mm512_shuffle_epi8(lut, clo); - accu_0 = _mm512_add_epi16(accu_0, res_lo); - accu_1 = _mm512_add_epi16(accu_1, _mm512_srli_epi16(res_lo, 8)); - let res_hi = _mm512_shuffle_epi8(lut, chi); - accu_2 = _mm512_add_epi16(accu_2, res_hi); - accu_3 = _mm512_add_epi16(accu_3, _mm512_srli_epi16(res_hi, 8)); + let lut = _mm512_loadu_si512(lut.as_ptr().add(i * 16).cast()); + let res_lo = _mm512_shuffle_epi8(lut, clo); + accu_0 = _mm512_add_epi16(accu_0, res_lo); + accu_1 = _mm512_add_epi16(accu_1, _mm512_srli_epi16(res_lo, 8)); + let res_hi = _mm512_shuffle_epi8(lut, chi); + accu_2 = _mm512_add_epi16(accu_2, res_hi); + accu_3 = _mm512_add_epi16(accu_3, _mm512_srli_epi16(res_hi, 8)); - i += 4; - } - if i + 2 <= width as usize { - let c = _mm256_loadu_si256(codes.as_ptr().add(i * 16).cast()); - - let mask = _mm256_set1_epi8(0xf); - let clo = _mm256_and_si256(c, mask); - let chi = _mm256_and_si256(_mm256_srli_epi16(c, 4), mask); - - let lut = _mm256_loadu_si256(lut.as_ptr().add(i * 16).cast()); - let res_lo = _mm512_zextsi256_si512(_mm256_shuffle_epi8(lut, clo)); - accu_0 = _mm512_add_epi16(accu_0, res_lo); - accu_1 = _mm512_add_epi16(accu_1, _mm512_srli_epi16(res_lo, 8)); - let res_hi = _mm512_zextsi256_si512(_mm256_shuffle_epi8(lut, chi)); - accu_2 = _mm512_add_epi16(accu_2, res_hi); - accu_3 = _mm512_add_epi16(accu_3, _mm512_srli_epi16(res_hi, 8)); - - i += 2; - } - if i < width as usize { - let c = _mm_loadu_si128(codes.as_ptr().add(i * 16).cast()); - - let mask = _mm_set1_epi8(0xf); - let clo = _mm_and_si128(c, mask); - let chi = _mm_and_si128(_mm_srli_epi16(c, 4), mask); - - let lut = _mm_loadu_si128(lut.as_ptr().add(i * 16).cast()); - let res_lo = _mm512_zextsi128_si512(_mm_shuffle_epi8(lut, clo)); - accu_0 = _mm512_add_epi16(accu_0, res_lo); - accu_1 = _mm512_add_epi16(accu_1, _mm512_srli_epi16(res_lo, 8)); - let res_hi = _mm512_zextsi128_si512(_mm_shuffle_epi8(lut, chi)); - accu_2 = _mm512_add_epi16(accu_2, res_hi); - accu_3 = _mm512_add_epi16(accu_3, _mm512_srli_epi16(res_hi, 8)); - - i += 1; - } - debug_assert_eq!(i, width as usize); + i += 4; + } + if i + 2 <= width as usize { + let c = _mm256_loadu_si256(codes.as_ptr().add(i * 16).cast()); + + let mask = _mm256_set1_epi8(0xf); + let clo = _mm256_and_si256(c, mask); + let chi = _mm256_and_si256(_mm256_srli_epi16(c, 4), mask); + + let lut = _mm256_loadu_si256(lut.as_ptr().add(i * 16).cast()); + let res_lo = _mm512_zextsi256_si512(_mm256_shuffle_epi8(lut, clo)); + accu_0 = _mm512_add_epi16(accu_0, res_lo); + accu_1 = _mm512_add_epi16(accu_1, _mm512_srli_epi16(res_lo, 8)); + let res_hi = 
_mm512_zextsi256_si512(_mm256_shuffle_epi8(lut, chi)); + accu_2 = _mm512_add_epi16(accu_2, res_hi); + accu_3 = _mm512_add_epi16(accu_3, _mm512_srli_epi16(res_hi, 8)); + + i += 2; + } + if i < width as usize { + let c = _mm_loadu_si128(codes.as_ptr().add(i * 16).cast()); + + let mask = _mm_set1_epi8(0xf); + let clo = _mm_and_si128(c, mask); + let chi = _mm_and_si128(_mm_srli_epi16(c, 4), mask); + + let lut = _mm_loadu_si128(lut.as_ptr().add(i * 16).cast()); + let res_lo = _mm512_zextsi128_si512(_mm_shuffle_epi8(lut, clo)); + accu_0 = _mm512_add_epi16(accu_0, res_lo); + accu_1 = _mm512_add_epi16(accu_1, _mm512_srli_epi16(res_lo, 8)); + let res_hi = _mm512_zextsi128_si512(_mm_shuffle_epi8(lut, chi)); + accu_2 = _mm512_add_epi16(accu_2, res_hi); + accu_3 = _mm512_add_epi16(accu_3, _mm512_srli_epi16(res_hi, 8)); + + i += 1; + } + debug_assert_eq!(i, width as usize); - let mut result = [0_u16; 32]; + let mut result = [0_u16; 32]; - accu_0 = _mm512_sub_epi16(accu_0, _mm512_slli_epi16(accu_1, 8)); - _mm256_storeu_si256( - result.as_mut_ptr().add(0).cast(), - combine4x2(accu_0, accu_1), - ); + accu_0 = _mm512_sub_epi16(accu_0, _mm512_slli_epi16(accu_1, 8)); + _mm256_storeu_si256( + result.as_mut_ptr().add(0).cast(), + combine4x2(accu_0, accu_1), + ); - accu_2 = _mm512_sub_epi16(accu_2, _mm512_slli_epi16(accu_3, 8)); - _mm256_storeu_si256( - result.as_mut_ptr().add(16).cast(), - combine4x2(accu_2, accu_3), - ); + accu_2 = _mm512_sub_epi16(accu_2, _mm512_slli_epi16(accu_3, 8)); + _mm256_storeu_si256( + result.as_mut_ptr().add(16).cast(), + combine4x2(accu_2, accu_3), + ); - result + result + } } -} -#[cfg(target_arch = "x86_64")] -#[detect::target_cpu(enable = "v3")] -unsafe fn fast_scan_v3(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { - // bounds checking is not enforced by compiler, so check it manually - assert_eq!(codes.len(), width as usize * 16); - assert_eq!(lut.len(), width as usize * 16); - - unsafe { - use std::arch::x86_64::*; - - #[inline] - #[detect::target_cpu(enable = "v3")] - unsafe fn combine2x2(x0x1: __m256i, y0y1: __m256i) -> __m256i { - unsafe { - let x1y0 = _mm256_permute2f128_si256(x0x1, y0y1, 0x21); - let x0y1 = _mm256_blend_epi32(x0x1, y0y1, 0xf0); - _mm256_add_epi16(x1y0, x0y1) + #[cfg(target_arch = "x86_64")] + #[test] + fn fast_scan_b4_v4_test() { + detect::init(); + if !detect::v4::detect() { + println!("test {} ... 
skipped (v4)", module_path!()); + return; + } + for _ in 0..200 { + for width in 90..110 { + let codes = (0..16 * width).map(|_| rand::random()).collect::>(); + let lut = (0..16 * width).map(|_| rand::random()).collect::>(); + unsafe { + assert_eq!( + fast_scan_b4_v4(width, &codes, &lut), + fast_scan_b4_fallback(width, &codes, &lut) + ); + } } } + } + + #[cfg(target_arch = "x86_64")] + #[detect::target_cpu(enable = "v3")] + unsafe fn fast_scan_b4_v3(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { + // bounds checking is not enforced by compiler, so check it manually + assert_eq!(codes.len(), width as usize * 16); + assert_eq!(lut.len(), width as usize * 16); + + unsafe { + use std::arch::x86_64::*; + + #[inline] + #[detect::target_cpu(enable = "v3")] + unsafe fn combine2x2(x0x1: __m256i, y0y1: __m256i) -> __m256i { + unsafe { + let x1y0 = _mm256_permute2f128_si256(x0x1, y0y1, 0x21); + let x0y1 = _mm256_blend_epi32(x0x1, y0y1, 0xf0); + _mm256_add_epi16(x1y0, x0y1) + } + } - let mut accu_0 = _mm256_setzero_si256(); - let mut accu_1 = _mm256_setzero_si256(); - let mut accu_2 = _mm256_setzero_si256(); - let mut accu_3 = _mm256_setzero_si256(); + let mut accu_0 = _mm256_setzero_si256(); + let mut accu_1 = _mm256_setzero_si256(); + let mut accu_2 = _mm256_setzero_si256(); + let mut accu_3 = _mm256_setzero_si256(); - let mut i = 0_usize; - while i + 2 <= width as usize { - let c = _mm256_loadu_si256(codes.as_ptr().add(i * 16).cast()); + let mut i = 0_usize; + while i + 2 <= width as usize { + let c = _mm256_loadu_si256(codes.as_ptr().add(i * 16).cast()); - let mask = _mm256_set1_epi8(0xf); - let clo = _mm256_and_si256(c, mask); - let chi = _mm256_and_si256(_mm256_srli_epi16(c, 4), mask); + let mask = _mm256_set1_epi8(0xf); + let clo = _mm256_and_si256(c, mask); + let chi = _mm256_and_si256(_mm256_srli_epi16(c, 4), mask); - let lut = _mm256_loadu_si256(lut.as_ptr().add(i * 16).cast()); - let res_lo = _mm256_shuffle_epi8(lut, clo); - accu_0 = _mm256_add_epi16(accu_0, res_lo); - accu_1 = _mm256_add_epi16(accu_1, _mm256_srli_epi16(res_lo, 8)); - let res_hi = _mm256_shuffle_epi8(lut, chi); - accu_2 = _mm256_add_epi16(accu_2, res_hi); - accu_3 = _mm256_add_epi16(accu_3, _mm256_srli_epi16(res_hi, 8)); + let lut = _mm256_loadu_si256(lut.as_ptr().add(i * 16).cast()); + let res_lo = _mm256_shuffle_epi8(lut, clo); + accu_0 = _mm256_add_epi16(accu_0, res_lo); + accu_1 = _mm256_add_epi16(accu_1, _mm256_srli_epi16(res_lo, 8)); + let res_hi = _mm256_shuffle_epi8(lut, chi); + accu_2 = _mm256_add_epi16(accu_2, res_hi); + accu_3 = _mm256_add_epi16(accu_3, _mm256_srli_epi16(res_hi, 8)); - i += 2; - } - if i < width as usize { - let c = _mm_loadu_si128(codes.as_ptr().add(i * 16).cast()); - - let mask = _mm_set1_epi8(0xf); - let clo = _mm_and_si128(c, mask); - let chi = _mm_and_si128(_mm_srli_epi16(c, 4), mask); - - let lut = _mm_loadu_si128(lut.as_ptr().add(i * 16).cast()); - let res_lo = _mm256_zextsi128_si256(_mm_shuffle_epi8(lut, clo)); - accu_0 = _mm256_add_epi16(accu_0, res_lo); - accu_1 = _mm256_add_epi16(accu_1, _mm256_srli_epi16(res_lo, 8)); - let res_hi = _mm256_zextsi128_si256(_mm_shuffle_epi8(lut, chi)); - accu_2 = _mm256_add_epi16(accu_2, res_hi); - accu_3 = _mm256_add_epi16(accu_3, _mm256_srli_epi16(res_hi, 8)); - - i += 1; - } - debug_assert_eq!(i, width as usize); + i += 2; + } + if i < width as usize { + let c = _mm_loadu_si128(codes.as_ptr().add(i * 16).cast()); + + let mask = _mm_set1_epi8(0xf); + let clo = _mm_and_si128(c, mask); + let chi = _mm_and_si128(_mm_srli_epi16(c, 4), mask); + + let 
lut = _mm_loadu_si128(lut.as_ptr().add(i * 16).cast()); + let res_lo = _mm256_zextsi128_si256(_mm_shuffle_epi8(lut, clo)); + accu_0 = _mm256_add_epi16(accu_0, res_lo); + accu_1 = _mm256_add_epi16(accu_1, _mm256_srli_epi16(res_lo, 8)); + let res_hi = _mm256_zextsi128_si256(_mm_shuffle_epi8(lut, chi)); + accu_2 = _mm256_add_epi16(accu_2, res_hi); + accu_3 = _mm256_add_epi16(accu_3, _mm256_srli_epi16(res_hi, 8)); + + i += 1; + } + debug_assert_eq!(i, width as usize); - let mut result = [0_u16; 32]; + let mut result = [0_u16; 32]; - accu_0 = _mm256_sub_epi16(accu_0, _mm256_slli_epi16(accu_1, 8)); - _mm256_storeu_si256( - result.as_mut_ptr().add(0).cast(), - combine2x2(accu_0, accu_1), - ); + accu_0 = _mm256_sub_epi16(accu_0, _mm256_slli_epi16(accu_1, 8)); + _mm256_storeu_si256( + result.as_mut_ptr().add(0).cast(), + combine2x2(accu_0, accu_1), + ); - accu_2 = _mm256_sub_epi16(accu_2, _mm256_slli_epi16(accu_3, 8)); - _mm256_storeu_si256( - result.as_mut_ptr().add(16).cast(), - combine2x2(accu_2, accu_3), - ); + accu_2 = _mm256_sub_epi16(accu_2, _mm256_slli_epi16(accu_3, 8)); + _mm256_storeu_si256( + result.as_mut_ptr().add(16).cast(), + combine2x2(accu_2, accu_3), + ); - result + result + } } -} -#[cfg(target_arch = "x86_64")] -#[detect::target_cpu(enable = "v2")] -unsafe fn fast_scan_v2(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { - // bounds checking is not enforced by compiler, so check it manually - assert_eq!(codes.len(), width as usize * 16); - assert_eq!(lut.len(), width as usize * 16); - - unsafe { - use std::arch::x86_64::*; - - let mut accu_0 = _mm_setzero_si128(); - let mut accu_1 = _mm_setzero_si128(); - let mut accu_2 = _mm_setzero_si128(); - let mut accu_3 = _mm_setzero_si128(); - - let mut i = 0_usize; - while i < width as usize { - let c = _mm_loadu_si128(codes.as_ptr().add(i * 16).cast()); - - let mask = _mm_set1_epi8(0xf); - let clo = _mm_and_si128(c, mask); - let chi = _mm_and_si128(_mm_srli_epi16(c, 4), mask); - - let lut = _mm_loadu_si128(lut.as_ptr().add(i * 16).cast()); - let res_lo = _mm_shuffle_epi8(lut, clo); - accu_0 = _mm_add_epi16(accu_0, res_lo); - accu_1 = _mm_add_epi16(accu_1, _mm_srli_epi16(res_lo, 8)); - let res_hi = _mm_shuffle_epi8(lut, chi); - accu_2 = _mm_add_epi16(accu_2, res_hi); - accu_3 = _mm_add_epi16(accu_3, _mm_srli_epi16(res_hi, 8)); - - i += 1; + #[cfg(target_arch = "x86_64")] + #[test] + fn fast_scan_b4_v3_test() { + detect::init(); + if !detect::v3::detect() { + println!("test {} ... 
skipped (v3)", module_path!()); + return; } - debug_assert_eq!(i, width as usize); + for _ in 0..200 { + for width in 90..110 { + let codes = (0..16 * width).map(|_| rand::random()).collect::>(); + let lut = (0..16 * width).map(|_| rand::random()).collect::>(); + unsafe { + assert_eq!( + fast_scan_b4_v3(width, &codes, &lut), + fast_scan_b4_fallback(width, &codes, &lut) + ); + } + } + } + } + + #[cfg(target_arch = "x86_64")] + #[detect::target_cpu(enable = "v2")] + unsafe fn fast_scan_b4_v2(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { + // bounds checking is not enforced by compiler, so check it manually + assert_eq!(codes.len(), width as usize * 16); + assert_eq!(lut.len(), width as usize * 16); + + unsafe { + use std::arch::x86_64::*; + + let mut accu_0 = _mm_setzero_si128(); + let mut accu_1 = _mm_setzero_si128(); + let mut accu_2 = _mm_setzero_si128(); + let mut accu_3 = _mm_setzero_si128(); + + let mut i = 0_usize; + while i < width as usize { + let c = _mm_loadu_si128(codes.as_ptr().add(i * 16).cast()); + + let mask = _mm_set1_epi8(0xf); + let clo = _mm_and_si128(c, mask); + let chi = _mm_and_si128(_mm_srli_epi16(c, 4), mask); + + let lut = _mm_loadu_si128(lut.as_ptr().add(i * 16).cast()); + let res_lo = _mm_shuffle_epi8(lut, clo); + accu_0 = _mm_add_epi16(accu_0, res_lo); + accu_1 = _mm_add_epi16(accu_1, _mm_srli_epi16(res_lo, 8)); + let res_hi = _mm_shuffle_epi8(lut, chi); + accu_2 = _mm_add_epi16(accu_2, res_hi); + accu_3 = _mm_add_epi16(accu_3, _mm_srli_epi16(res_hi, 8)); + + i += 1; + } + debug_assert_eq!(i, width as usize); - let mut result = [0_u16; 32]; + let mut result = [0_u16; 32]; - accu_0 = _mm_sub_epi16(accu_0, _mm_slli_epi16(accu_1, 8)); - _mm_storeu_si128(result.as_mut_ptr().add(0).cast(), accu_0); - _mm_storeu_si128(result.as_mut_ptr().add(8).cast(), accu_1); + accu_0 = _mm_sub_epi16(accu_0, _mm_slli_epi16(accu_1, 8)); + _mm_storeu_si128(result.as_mut_ptr().add(0).cast(), accu_0); + _mm_storeu_si128(result.as_mut_ptr().add(8).cast(), accu_1); - accu_2 = _mm_sub_epi16(accu_2, _mm_slli_epi16(accu_3, 8)); - _mm_storeu_si128(result.as_mut_ptr().add(16).cast(), accu_2); - _mm_storeu_si128(result.as_mut_ptr().add(24).cast(), accu_3); + accu_2 = _mm_sub_epi16(accu_2, _mm_slli_epi16(accu_3, 8)); + _mm_storeu_si128(result.as_mut_ptr().add(16).cast(), accu_2); + _mm_storeu_si128(result.as_mut_ptr().add(24).cast(), accu_3); - result + result + } } -} -#[cfg(target_arch = "x86_64")] -#[test] -fn test_v4_v3() { - detect::init(); - if !detect::v4::detect() || !detect::v3::detect() { - println!("test {} ... skipped (v4, v3)", module_path!()); - return; - } - for _ in 0..200 { - for width in 90..110 { - let codes = (0..16 * width).map(|_| rand::random()).collect::>(); - let lut = (0..16 * width).map(|_| rand::random()).collect::>(); - unsafe { - assert_eq!( - fast_scan_v4(width, &codes, &lut), - fast_scan_v3(width, &codes, &lut) - ); + #[cfg(target_arch = "x86_64")] + #[test] + fn fast_scan_b4_v2_test() { + detect::init(); + if !detect::v2::detect() { + println!("test {} ... skipped (v2)", module_path!()); + return; + } + for _ in 0..200 { + for width in 90..110 { + let codes = (0..16 * width).map(|_| rand::random()).collect::>(); + let lut = (0..16 * width).map(|_| rand::random()).collect::>(); + unsafe { + assert_eq!( + fast_scan_b4_v2(width, &codes, &lut), + fast_scan_b4_fallback(width, &codes, &lut) + ); + } } } } -} -#[cfg(target_arch = "x86_64")] -#[test] -fn test_v3_v2() { - detect::init(); - if !detect::v3::detect() || !detect::v2::detect() { - println!("test {} ... 
skipped (v3, v2)", module_path!());
-        return;
-    }
-    for _ in 0..200 {
-        for width in 90..110 {
-            let codes = (0..16 * width).map(|_| rand::random()).collect::<Vec<u8>>();
-            let lut = (0..16 * width).map(|_| rand::random()).collect::<Vec<u8>>();
-            unsafe {
-                assert_eq!(
-                    fast_scan_v3(width, &codes, &lut),
-                    fast_scan_v2(width, &codes, &lut)
-                );
-            }
-        }
-    }
+    #[detect::multiversion(v4 = import, v3 = import, v2 = import, neon, fallback = export)]
+    pub fn fast_scan_b4(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] {
+        let width = width as usize;
+
+        assert_eq!(codes.len(), width * 16);
+        assert_eq!(lut.len(), width * 16);
+
+        use std::array::from_fn;
+        use std::ops::BitAnd;
+
+        fn load<T: Copy, const N: usize>(slice: &[T]) -> [T; N] {
+            from_fn(|i| slice[i])
+        }
+        fn unary<T: Copy, U, const N: usize>(op: impl Fn(T) -> U, a: [T; N]) -> [U; N] {
+            from_fn(|i| op(a[i]))
+        }
+        fn binary<T: Copy, const N: usize>(op: impl Fn(T, T) -> T, a: [T; N], b: [T; N]) -> [T; N] {
+            from_fn(|i| op(a[i], b[i]))
+        }
+        fn shuffle<T: Copy, const N: usize>(a: [T; N], b: [u8; N]) -> [T; N] {
+            from_fn(|i| a[b[i] as usize])
+        }
+        fn cast(x: [u8; 16]) -> [u16; 8] {
+            from_fn(|i| u16::from_le_bytes([x[i << 1 | 0], x[i << 1 | 1]]))
+        }
+        fn setr<T: Copy>(x: [[T; 8]; 4]) -> [T; 32] {
+            from_fn(|i| x[i >> 3][i & 7])
+        }
+
+        let mut a_0 = [0u16; 8];
+        let mut a_1 = [0u16; 8];
+        let mut a_2 = [0u16; 8];
+        let mut a_3 = [0u16; 8];
+
+        for i in 0..width {
+            let c = load(&codes[16 * i as usize..]);
+
+            let mask = [0xfu8; 16];
+            let clo = binary(u8::bitand, c, mask);
+            let chi = binary(u8::bitand, unary(|x| x >> 4, c), mask);
+
+            let lut = load(&lut[16 * i as usize..]);
+            let res_lo = cast(shuffle(lut, clo));
+            a_0 = binary(u16::wrapping_add, a_0, res_lo);
+            a_1 = binary(u16::wrapping_add, a_1, unary(|x| x >> 8, res_lo));
+            let res_hi = cast(shuffle(lut, chi));
+            a_2 = binary(u16::wrapping_add, a_2, res_hi);
+            a_3 = binary(u16::wrapping_add, a_3, unary(|x| x >> 8, res_hi));
+        }
+
+        a_0 = binary(u16::wrapping_sub, a_0, unary(|x| x.wrapping_shl(8), a_1));
+        a_2 = binary(u16::wrapping_sub, a_2, unary(|x| x.wrapping_shl(8), a_3));
+
+        setr([a_0, a_1, a_2, a_3])
+    }
 }
+
+#[inline(always)]
+pub fn fast_scan_b4(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] {
+    fast_scan_b4::fast_scan_b4(width, codes, lut)
+}
diff --git a/crates/quantization/src/product/mod.rs b/crates/quantization/src/product/mod.rs
index 3955964b7..8f447dbc0 100644
--- a/crates/quantization/src/product/mod.rs
+++ b/crates/quantization/src/product/mod.rs
@@ -140,8 +140,8 @@ impl ProductQuantizer {
         let dims = self.dims;
         let ratio = self.ratio;
         let width = dims.div_ceil(ratio);
-        if fast_scan && self.bits == 4 && crate::fast_scan::b4::is_supported() {
-            use crate::fast_scan::b4::{fast_scan, BLOCK_SIZE};
+        if fast_scan && self.bits == 4 {
+            use crate::fast_scan::b4::{fast_scan_b4, BLOCK_SIZE};
             let (k, b, lut) = O::fscan_preprocess(preprocessed);
             let s = rhs.start.next_multiple_of(BLOCK_SIZE);
             let e = (rhs.end + 1 - BLOCK_SIZE).next_multiple_of(BLOCK_SIZE);
@@ -150,7 +150,7 @@
                 let bytes = width as usize * 16;
                 let start = (i / BLOCK_SIZE) as usize * bytes;
                 let end = start + bytes;
-                let res = fast_scan(width, &packed_codes[start..end], &lut);
+                let res = fast_scan_b4(width, &packed_codes[start..end], &lut);
                 let r = res.map(|x| O::fscan_process(width, k, b, x));
                 heap.extend({
                     (rhs.start..s).map(|u| (Reverse(r[(u - i) as usize]), AlwaysEqual(u)))
@@ -160,7 +160,7 @@
                 let bytes = width as usize * 16;
                 let start = (i / BLOCK_SIZE) as usize * bytes;
                 let end = start + bytes;
-                let res = fast_scan(width, &packed_codes[start..end], &lut);
+                let res = 
fast_scan_b4(width, &packed_codes[start..end], &lut); let r = res.map(|x| O::fscan_process(width, k, b, x)); heap.extend({ (i..i + BLOCK_SIZE).map(|u| (Reverse(r[(u - i) as usize]), AlwaysEqual(u))) @@ -171,7 +171,7 @@ impl ProductQuantizer { let bytes = width as usize * 16; let start = (i / BLOCK_SIZE) as usize * bytes; let end = start + bytes; - let res = fast_scan(width, &packed_codes[start..end], &lut); + let res = fast_scan_b4(width, &packed_codes[start..end], &lut); let r = res.map(|x| O::fscan_process(width, k, b, x)); heap.extend({ (e..rhs.end).map(|u| (Reverse(r[(u - i) as usize]), AlwaysEqual(u))) diff --git a/crates/quantization/src/scalar/mod.rs b/crates/quantization/src/scalar/mod.rs index 83ade933d..1dad72f0e 100644 --- a/crates/quantization/src/scalar/mod.rs +++ b/crates/quantization/src/scalar/mod.rs @@ -125,8 +125,8 @@ impl ScalarQuantizer { ) { let dims = self.dims; let width = dims; - if fast_scan && self.bits == 4 && crate::fast_scan::b4::is_supported() { - use crate::fast_scan::b4::{fast_scan, BLOCK_SIZE}; + if fast_scan && self.bits == 4 { + use crate::fast_scan::b4::{fast_scan_b4, BLOCK_SIZE}; let (k, b, lut) = O::fscan_preprocess(preprocessed); let s = rhs.start.next_multiple_of(BLOCK_SIZE); let e = (rhs.end + 1 - BLOCK_SIZE).next_multiple_of(BLOCK_SIZE); @@ -135,7 +135,7 @@ impl ScalarQuantizer { let bytes = width as usize * 16; let start = (i / BLOCK_SIZE) as usize * bytes; let end = start + bytes; - let res = fast_scan(width, &packed_codes[start..end], &lut); + let res = fast_scan_b4(width, &packed_codes[start..end], &lut); let r = res.map(|x| O::fscan_process(width, k, b, x)); heap.extend({ (rhs.start..s).map(|u| (Reverse(r[(u - i) as usize]), AlwaysEqual(u))) @@ -145,7 +145,7 @@ impl ScalarQuantizer { let bytes = width as usize * 16; let start = (i / BLOCK_SIZE) as usize * bytes; let end = start + bytes; - let res = fast_scan(width, &packed_codes[start..end], &lut); + let res = fast_scan_b4(width, &packed_codes[start..end], &lut); let r = res.map(|x| O::fscan_process(width, k, b, x)); heap.extend({ (i..i + BLOCK_SIZE).map(|u| (Reverse(r[(u - i) as usize]), AlwaysEqual(u))) @@ -156,7 +156,7 @@ impl ScalarQuantizer { let bytes = width as usize * 16; let start = (i / BLOCK_SIZE) as usize * bytes; let end = start + bytes; - let res = fast_scan(width, &packed_codes[start..end], &lut); + let res = fast_scan_b4(width, &packed_codes[start..end], &lut); let r = res.map(|x| O::fscan_process(width, k, b, x)); heap.extend({ (e..rhs.end).map(|u| (Reverse(r[(u - i) as usize]), AlwaysEqual(u))) diff --git a/crates/rabitq/src/quant/quantizer.rs b/crates/rabitq/src/quant/quantizer.rs index 9855e7402..64f48b8a4 100644 --- a/crates/rabitq/src/quant/quantizer.rs +++ b/crates/rabitq/src/quant/quantizer.rs @@ -103,8 +103,8 @@ impl RabitqQuantizer { epsilon: f32, fast_scan: bool, ) { - if fast_scan && quantization::fast_scan::b4::is_supported() { - use quantization::fast_scan::b4::{fast_scan, BLOCK_SIZE}; + if fast_scan { + use quantization::fast_scan::b4::{fast_scan_b4, BLOCK_SIZE}; let s = rhs.start.next_multiple_of(BLOCK_SIZE); let e = (rhs.end + 1 - BLOCK_SIZE).next_multiple_of(BLOCK_SIZE); let lut = O::fscan_preprocess(p1); @@ -114,7 +114,7 @@ impl RabitqQuantizer { let bytes = (t * 16) as usize; let start = (i / BLOCK_SIZE) as usize * bytes; let end = start + bytes; - let res = fast_scan(t, &packed_codes[start..end], &lut); + let res = fast_scan_b4(t, &packed_codes[start..end], &lut); heap.extend({ (rhs.start..s).map(|u| { ( @@ -136,7 +136,7 @@ impl RabitqQuantizer { let bytes 
= (t * 16) as usize; let start = (i / BLOCK_SIZE) as usize * bytes; let end = start + bytes; - let res = fast_scan(t, &packed_codes[start..end], &lut); + let res = fast_scan_b4(t, &packed_codes[start..end], &lut); heap.extend({ (i..i + BLOCK_SIZE).map(|u| { ( @@ -159,7 +159,7 @@ impl RabitqQuantizer { let bytes = (t * 16) as usize; let start = (i / BLOCK_SIZE) as usize * bytes; let end = start + bytes; - let res = fast_scan(t, &packed_codes[start..end], &lut); + let res = fast_scan_b4(t, &packed_codes[start..end], &lut); heap.extend({ (e..rhs.end).map(|u| { (
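
Editor's sketch, not part of the patch: a minimal standalone reference (hypothetical name `reference_scan_b4`) of the value the new fallback is expected to return for one block, assuming the layout the kernels use: each of the `width` segments contributes 16 packed code bytes (two 4-bit codes per byte, 32 lanes per block) and one 16-byte LUT, with accumulation wrapping in u16 just as the SIMD paths do. The lane-to-nibble mapping in the comments is my reading of the fallback's `cast`/`setr` helpers.

    // Hypothetical reference implementation; the layout assumptions are spelled out above.
    fn reference_scan_b4(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] {
        assert_eq!(codes.len(), width as usize * 16);
        assert_eq!(lut.len(), width as usize * 16);
        let mut result = [0u16; 32];
        for i in 0..width as usize {
            let c = &codes[16 * i..16 * (i + 1)]; // 16 packed code bytes of segment i
            let t = &lut[16 * i..16 * (i + 1)];   // 16-entry lookup table of segment i
            for k in 0..8 {
                // lanes 0..8:   low nibble of even bytes
                // lanes 8..16:  low nibble of odd bytes
                // lanes 16..24: high nibble of even bytes
                // lanes 24..32: high nibble of odd bytes
                result[k] = result[k].wrapping_add(t[(c[2 * k] & 0xf) as usize] as u16);
                result[8 + k] = result[8 + k].wrapping_add(t[(c[2 * k + 1] & 0xf) as usize] as u16);
                result[16 + k] = result[16 + k].wrapping_add(t[(c[2 * k] >> 4) as usize] as u16);
                result[24 + k] = result[24 + k].wrapping_add(t[(c[2 * k + 1] >> 4) as usize] as u16);
            }
        }
        result
    }

    fn main() {
        // Toy block: one segment, every byte 0x21 (low nibble 1, high nibble 2), identity LUT.
        let codes = [0x21u8; 16];
        let lut: Vec<u8> = (0..16).collect();
        let r = reference_scan_b4(1, &codes, &lut);
        assert_eq!(r[0], 1);  // low nibble of byte 0
        assert_eq!(r[16], 2); // high nibble of byte 0
        println!("{:?}", r);
    }

The `main` check above is only a toy illustration; the patch's own tests exercise the real kernels by comparing each SIMD variant (`fast_scan_b4_v4`, `_v3`, `_v2`) against the exported fallback over random codes and LUTs.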