From b24045786cd24156a67f2fbb181149f2676e9d78 Mon Sep 17 00:00:00 2001 From: usamoi Date: Sat, 14 Sep 2024 23:16:36 +0800 Subject: [PATCH] refactor: rabitq Signed-off-by: usamoi --- Cargo.lock | 35 +- crates/base/src/index.rs | 75 +- crates/base/src/scalar/f16.rs | 14 + crates/base/src/scalar/f32.rs | 14 + crates/base/src/scalar/impossible.rs | 13 + crates/base/src/scalar/mod.rs | 3 + crates/base/src/vector/vect.rs | 5 + crates/cli/src/args.rs | 9 +- crates/common/src/vec2.rs | 3 + crates/flat/src/lib.rs | 5 +- crates/hnsw/src/lib.rs | 20 +- crates/index/src/segment/sealed.rs | 17 +- crates/indexing/Cargo.toml | 1 - crates/indexing/src/lib.rs | 6 +- crates/indexing/src/sealed.rs | 59 +- crates/ivf/src/lib.rs | 51 +- crates/ivf/src/operator.rs | 76 +- crates/k_means/src/elkan.rs | 264 ------- crates/k_means/src/lib.rs | 27 +- crates/quantization/Cargo.toml | 3 + crates/quantization/src/lib.rs | 104 ++- crates/quantization/src/product.rs | 44 +- crates/quantization/src/quantizer.rs | 6 +- crates/quantization/src/rabitq.rs | 725 ++++++++++++++++++ .../src/reranker}/error.rs | 0 crates/quantization/src/reranker/mod.rs | 1 + crates/quantization/src/scalar.rs | 41 +- crates/quantization/src/trivial.rs | 8 +- crates/rabitq/Cargo.toml | 24 - crates/rabitq/src/lib.rs | 255 ------ crates/rabitq/src/operator.rs | 694 ----------------- crates/rabitq/src/quant/mod.rs | 3 - crates/rabitq/src/quant/quantization.rs | 192 ----- crates/rabitq/src/quant/quantizer.rs | 247 ------ crates/stoppable_rayon/src/lib.rs | 25 +- src/gucs/executing.rs | 66 +- 36 files changed, 1118 insertions(+), 2017 deletions(-) delete mode 100644 crates/k_means/src/elkan.rs create mode 100644 crates/quantization/src/rabitq.rs rename crates/{rabitq/src/quant => quantization/src/reranker}/error.rs (100%) delete mode 100644 crates/rabitq/Cargo.toml delete mode 100644 crates/rabitq/src/lib.rs delete mode 100644 crates/rabitq/src/operator.rs delete mode 100644 crates/rabitq/src/quant/mod.rs delete mode 100644 crates/rabitq/src/quant/quantization.rs delete mode 100644 crates/rabitq/src/quant/quantizer.rs diff --git a/Cargo.lock b/Cargo.lock index 5bc28158f..10a31d200 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -102,7 +102,6 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" dependencies = [ - "num-complex", "num-traits", ] @@ -550,9 +549,9 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.16.3" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "102087e286b4677862ea56cf8fc58bb2cdfa8725c40ffb80fe3a008eb7f2fc83" +checksum = "94bbb0ad554ad961ddc5da507a12a29b14e4ae5bda06b19f575a3e6079d2e2ae" [[package]] name = "byteorder" @@ -1522,7 +1521,6 @@ dependencies = [ "inverted", "ivf", "quantization", - "rabitq", "thiserror", ] @@ -1876,8 +1874,6 @@ dependencies = [ "num-complex", "num-rational", "num-traits", - "rand", - "rand_distr", "simba", "typenum", ] @@ -2454,7 +2450,10 @@ dependencies = [ "detect", "k_means", "log", + "nalgebra", "rand", + "rand_chacha", + "rand_distr", "serde", "serde_json", "stoppable_rayon", @@ -2517,26 +2516,6 @@ dependencies = [ "proc-macro2", ] -[[package]] -name = "rabitq" -version = "0.0.0" -dependencies = [ - "base", - "common", - "detect", - "half 2.4.1", - "k_means", - "log", - "nalgebra", - "quantization", - "rand", - "rand_distr", - "serde", - "serde_json", - "stoppable_rayon", - "storage", -] - [[package]] name = "radium" version = "0.7.0" @@ -3619,9 +3598,9 @@ dependencies = [ [[package]] name = "wide" -version = "0.7.26" +version = "0.7.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "901e8597c777fa042e9e245bd56c0dc4418c5db3f845b6ff94fbac732c6a0692" +checksum = "b828f995bf1e9622031f8009f8481a85406ce1f4d4588ff746d872043e855690" dependencies = [ "bytemuck", "safe_arch", diff --git a/crates/base/src/index.rs b/crates/base/src/index.rs index 77765b43d..13efef96c 100644 --- a/crates/base/src/index.rs +++ b/crates/base/src/index.rs @@ -110,10 +110,14 @@ impl IndexOptions { ) -> Result<(), ValidationError> { match quantization { None => Ok(()), - Some(QuantizationOptions::Scalar(_) | QuantizationOptions::Product(_)) => { + Some( + QuantizationOptions::Scalar(_) + | QuantizationOptions::Product(_) + | QuantizationOptions::Rabitq(_), + ) => { if !matches!(self.vector.v, VectorKind::Vecf32 | VectorKind::Vecf16) { return Err(ValidationError::new( - "scalar quantization or product quantization is not support for vectors that are not dense vectors", + "quantization is not support for vectors that are not dense vectors", )); } Ok(()) @@ -148,18 +152,6 @@ impl IndexOptions { )); } } - IndexingOptions::Rabitq(_) => { - if !matches!(self.vector.d, DistanceKind::L2 | DistanceKind::Dot) { - return Err(ValidationError::new( - "rabitq is not support for distance that is not l2 or dot", - )); - } - if !matches!(self.vector.v, VectorKind::Vecf32) { - return Err(ValidationError::new( - "rabitq is not support for vectors that are not vector", - )); - } - } } Ok(()) } @@ -293,7 +285,6 @@ pub enum IndexingOptions { Ivf(IvfIndexingOptions), Hnsw(HnswIndexingOptions), InvertedIndex(InvertedIndexingOptions), - Rabitq(RabitqIndexingOptions), } impl IndexingOptions { @@ -315,12 +306,6 @@ impl IndexingOptions { }; x } - pub fn unwrap_rabitq(self) -> RabitqIndexingOptions { - let IndexingOptions::Rabitq(x) = self else { - unreachable!() - }; - x - } } impl Default for IndexingOptions { @@ -336,7 +321,6 @@ impl Validate for IndexingOptions { Self::Ivf(x) => x.validate(), Self::Hnsw(x) => x.validate(), Self::InvertedIndex(x) => x.validate(), - Self::Rabitq(x) => x.validate(), } } } @@ -480,6 +464,7 @@ impl Default for RabitqIndexingOptions { pub enum QuantizationOptions { Scalar(ScalarQuantizationOptions), Product(ProductQuantizationOptions), + Rabitq(RabitqQuantizationOptions), } impl Validate for QuantizationOptions { @@ -487,6 +472,7 @@ impl Validate for QuantizationOptions { match self { Self::Scalar(x) => x.validate(), Self::Product(x) => x.validate(), + Self::Rabitq(x) => x.validate(), } } } @@ -554,6 +540,18 @@ impl Default for ProductQuantizationOptions { } } +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +pub struct RabitqQuantizationOptions {} + +impl RabitqQuantizationOptions {} + +impl Default for RabitqQuantizationOptions { + fn default() -> Self { + Self {} + } +} + #[derive(Debug, Clone, Serialize, Deserialize, Validate, Alter)] #[serde(deny_unknown_fields)] pub struct SearchOptions { @@ -567,23 +565,14 @@ pub struct SearchOptions { pub pq_rerank_size: u32, #[serde(default = "SearchOptions::default_pq_fast_scan")] pub pq_fast_scan: bool, + #[serde(default = "SearchOptions::default_rq_fast_scan")] + pub rq_fast_scan: bool, #[serde(default = "SearchOptions::default_ivf_nprobe")] #[validate(range(min = 1, max = 65535))] pub ivf_nprobe: u32, #[serde(default = "SearchOptions::default_hnsw_ef_search")] #[validate(range(min = 1, max = 65535))] pub hnsw_ef_search: u32, - #[serde(default = "SearchOptions::default_rabitq_nprobe")] - #[validate(range(min = 1, max = 65535))] - pub rabitq_nprobe: u32, - #[serde(default = "SearchOptions::default_rabitq_epsilon")] - #[validate(range(min = 1.0, max = 4.0))] - pub rabitq_epsilon: f32, - #[serde(default = "SearchOptions::default_rabitq_fast_scan")] - pub rabitq_fast_scan: bool, - #[serde(default = "SearchOptions::default_diskann_ef_search")] - #[validate(range(min = 1, max = 65535))] - pub diskann_ef_search: u32, } impl SearchOptions { @@ -599,24 +588,15 @@ impl SearchOptions { pub const fn default_pq_fast_scan() -> bool { false } + pub const fn default_rq_fast_scan() -> bool { + true + } pub const fn default_ivf_nprobe() -> u32 { 10 } pub const fn default_hnsw_ef_search() -> u32 { 100 } - pub const fn default_rabitq_nprobe() -> u32 { - 10 - } - pub const fn default_rabitq_epsilon() -> f32 { - 1.9 - } - pub const fn default_rabitq_fast_scan() -> bool { - true - } - pub const fn default_diskann_ef_search() -> u32 { - 100 - } } impl Default for SearchOptions { @@ -626,12 +606,9 @@ impl Default for SearchOptions { sq_fast_scan: Self::default_sq_fast_scan(), pq_rerank_size: Self::default_pq_rerank_size(), pq_fast_scan: Self::default_pq_fast_scan(), + rq_fast_scan: Self::default_rq_fast_scan(), ivf_nprobe: Self::default_ivf_nprobe(), hnsw_ef_search: Self::default_hnsw_ef_search(), - rabitq_nprobe: Self::default_rabitq_nprobe(), - rabitq_epsilon: Self::default_rabitq_epsilon(), - rabitq_fast_scan: Self::default_rabitq_fast_scan(), - diskann_ef_search: Self::default_diskann_ef_search(), } } } diff --git a/crates/base/src/scalar/f16.rs b/crates/base/src/scalar/f16.rs index fcc0f305b..9dafac923 100644 --- a/crates/base/src/scalar/f16.rs +++ b/crates/base/src/scalar/f16.rs @@ -37,6 +37,16 @@ impl ScalarLike for f16 { lhs * rhs } + #[inline(always)] + fn scalar_is_sign_positive(self) -> bool { + self.is_sign_positive() + } + + #[inline(always)] + fn scalar_is_sign_negative(self) -> bool { + self.is_sign_negative() + } + #[inline(always)] fn from_f32(x: f32) -> Self { f16::from_f32(x) @@ -236,6 +246,10 @@ impl ScalarLike for f16 { r } + fn vector_to_f32_borrowed(this: &[Self]) -> impl AsRef<[f32]> { + Self::vector_to_f32(this) + } + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn kmeans_helper(this: &mut [f16], x: f32, y: f32) { let x = f16::from_f32(x); diff --git a/crates/base/src/scalar/f32.rs b/crates/base/src/scalar/f32.rs index 5b318ff40..94783a457 100644 --- a/crates/base/src/scalar/f32.rs +++ b/crates/base/src/scalar/f32.rs @@ -36,6 +36,16 @@ impl ScalarLike for f32 { lhs * rhs } + #[inline(always)] + fn scalar_is_sign_positive(self) -> bool { + self.is_sign_positive() + } + + #[inline(always)] + fn scalar_is_sign_negative(self) -> bool { + self.is_sign_negative() + } + #[inline(always)] fn from_f32(x: f32) -> Self { x @@ -187,6 +197,10 @@ impl ScalarLike for f32 { this.to_vec() } + fn vector_to_f32_borrowed(this: &[Self]) -> impl AsRef<[f32]> { + this + } + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn kmeans_helper(this: &mut [f32], x: f32, y: f32) { let n = this.len(); diff --git a/crates/base/src/scalar/impossible.rs b/crates/base/src/scalar/impossible.rs index 8d23b789b..a241d2296 100644 --- a/crates/base/src/scalar/impossible.rs +++ b/crates/base/src/scalar/impossible.rs @@ -39,6 +39,14 @@ impl ScalarLike for Impossible { unimplemented!() } + fn scalar_is_sign_positive(self) -> bool { + unimplemented!() + } + + fn scalar_is_sign_negative(self) -> bool { + unimplemented!() + } + fn from_f32(_: f32) -> Self { unimplemented!() } @@ -101,6 +109,11 @@ impl ScalarLike for Impossible { unimplemented!() } + #[allow(unreachable_code)] + fn vector_to_f32_borrowed(_: &[Self]) -> impl AsRef<[f32]> { + unimplemented!() as Vec + } + fn vector_add(_lhs: &[Self], _rhs: &[Self]) -> Vec { unimplemented!() } diff --git a/crates/base/src/scalar/mod.rs b/crates/base/src/scalar/mod.rs index 4e41a43bf..2e22591f2 100644 --- a/crates/base/src/scalar/mod.rs +++ b/crates/base/src/scalar/mod.rs @@ -24,6 +24,8 @@ pub trait ScalarLike: fn scalar_add(lhs: Self, rhs: Self) -> Self; fn scalar_sub(lhs: Self, rhs: Self) -> Self; fn scalar_mul(lhs: Self, rhs: Self) -> Self; + fn scalar_is_sign_positive(self) -> bool; + fn scalar_is_sign_negative(self) -> bool; fn from_f32(x: f32) -> Self; fn to_f32(self) -> f32; @@ -42,6 +44,7 @@ pub trait ScalarLike: fn vector_from_f32(this: &[f32]) -> Vec; fn vector_to_f32(this: &[Self]) -> Vec; + fn vector_to_f32_borrowed(this: &[Self]) -> impl AsRef<[f32]>; fn vector_add(lhs: &[Self], rhs: &[Self]) -> Vec; fn vector_add_inplace(lhs: &mut [Self], rhs: &[Self]); fn vector_sub(lhs: &[Self], rhs: &[Self]) -> Vec; diff --git a/crates/base/src/vector/vect.rs b/crates/base/src/vector/vect.rs index 4f7db360d..7e6c08473 100644 --- a/crates/base/src/vector/vect.rs +++ b/crates/base/src/vector/vect.rs @@ -39,6 +39,11 @@ impl VectOwned { pub fn slice_mut(&mut self) -> &mut [S] { self.0.as_mut_slice() } + + #[inline(always)] + pub fn into_vec(self) -> Vec { + self.0 + } } impl VectorOwned for VectOwned { diff --git a/crates/cli/src/args.rs b/crates/cli/src/args.rs index 1004995f6..a24e29ff0 100644 --- a/crates/cli/src/args.rs +++ b/crates/cli/src/args.rs @@ -133,14 +133,11 @@ impl QueryArguments { SearchOptions { sq_rerank_size: 0, pq_rerank_size: 0, - hnsw_ef_search: self.ef, - ivf_nprobe: self.probe, - diskann_ef_search: 100, sq_fast_scan: false, pq_fast_scan: false, - rabitq_epsilon: 1.9, - rabitq_fast_scan: true, - rabitq_nprobe: self.probe, + rq_fast_scan: true, + hnsw_ef_search: self.ef, + ivf_nprobe: self.probe, } } } diff --git a/crates/common/src/vec2.rs b/crates/common/src/vec2.rs index 60d245310..2ad7f1251 100644 --- a/crates/common/src/vec2.rs +++ b/crates/common/src/vec2.rs @@ -32,6 +32,9 @@ impl Vec2 { } impl Vec2 { + pub fn shape(&self) -> (usize, usize) { + self.shape + } pub fn shape_0(&self) -> usize { self.shape.0 } diff --git a/crates/flat/src/lib.rs b/crates/flat/src/lib.rs index 1b1424cb1..763bcbdec 100644 --- a/crates/flat/src/lib.rs +++ b/crates/flat/src/lib.rs @@ -5,6 +5,7 @@ use base::index::*; use base::operator::*; use base::search::*; use base::vector::VectorBorrowed; +use base::vector::VectorOwned; use common::mmap_array::MmapArray; use common::remap::RemappedCollection; use quantization::quantizer::Quantizer; @@ -44,7 +45,9 @@ impl> Flat { opts: &'a SearchOptions, ) -> Box + 'a> { let mut heap = Q::flat_rerank_start(); - let lut = self.quantization.flat_rerank_preprocess(vector, opts); + let lut = self + .quantization + .flat_rerank_preprocess(self.quantization.project(vector).as_borrowed(), opts); self.quantization .flat_rerank_continue(&lut, 0..self.storage.len(), &mut heap); let mut reranker = self.quantization.flat_rerank_break( diff --git a/crates/hnsw/src/lib.rs b/crates/hnsw/src/lib.rs index 7ac72337f..4ca38d280 100644 --- a/crates/hnsw/src/lib.rs +++ b/crates/hnsw/src/lib.rs @@ -7,6 +7,7 @@ use base::index::*; use base::operator::*; use base::search::*; use base::vector::VectorBorrowed; +use base::vector::VectorOwned; use common::json::Json; use common::mmap_array::MmapArray; use common::remap::RemappedCollection; @@ -71,16 +72,15 @@ impl> Hnsw { let Some(s) = self.s else { return Box::new(std::iter::empty()); }; - let s = { - let processed = self.quantization.preprocess(vector); - fast_search( - |u| self.quantization.process(&self.storage, &processed, u), - |x, i| hyper_outs(self, x, i), - 1..=hierarchy_for_a_vertex(*self.m, s) - 1, - s, - ) - }; - let reranker = self.quantization.graph_rerank(vector, move |u| { + let projected_vector = self.quantization.project(vector); + let lut = self.quantization.preprocess(projected_vector.as_borrowed()); + let s = fast_search( + |u| self.quantization.process(&self.storage, &lut, u), + |x, i| hyper_outs(self, x, i), + 1..=hierarchy_for_a_vertex(*self.m, s) - 1, + s, + ); + let reranker = self.quantization.graph_rerank(lut, move |u| { ( O::distance(self.storage.vector(u), vector), (base_outs(self, u), ()), diff --git a/crates/index/src/segment/sealed.rs b/crates/index/src/segment/sealed.rs index 21bb1eb17..de8d91631 100644 --- a/crates/index/src/segment/sealed.rs +++ b/crates/index/src/segment/sealed.rs @@ -117,22 +117,7 @@ impl SealedSegment { } pub fn indexing(&self) -> &dyn Any { - match &self.indexing { - SealedIndexing::Flat(x) => x, - SealedIndexing::FlatPq(x) => x, - SealedIndexing::FlatSq(x) => x, - - SealedIndexing::Ivf(x) => x, - SealedIndexing::IvfPq(x) => x, - SealedIndexing::IvfSq(x) => x, - - SealedIndexing::Hnsw(x) => x, - SealedIndexing::HnswPq(x) => x, - SealedIndexing::HnswSq(x) => x, - - SealedIndexing::InvertedIndex(x) => x, - SealedIndexing::Rabitq(x) => x, - } + self.indexing.as_any() } } diff --git a/crates/indexing/Cargo.toml b/crates/indexing/Cargo.toml index 4bbec8a1d..0d5ed0976 100644 --- a/crates/indexing/Cargo.toml +++ b/crates/indexing/Cargo.toml @@ -14,7 +14,6 @@ hnsw = { path = "../hnsw" } inverted = { path = "../inverted" } ivf = { path = "../ivf" } quantization = { path = "../quantization" } -rabitq = { path = "../rabitq" } [lints] workspace = true diff --git a/crates/indexing/src/lib.rs b/crates/indexing/src/lib.rs index 0ca325662..19af8cd36 100644 --- a/crates/indexing/src/lib.rs +++ b/crates/indexing/src/lib.rs @@ -1,5 +1,6 @@ pub mod sealed; +use quantization::rabitq::OperatorRabitqQuantization; pub use sealed::SealedIndexing; use base::operator::Operator; @@ -7,16 +8,15 @@ use inverted::operator::OperatorInvertedIndex; use ivf::operator::OperatorIvf; use quantization::product::OperatorProductQuantization; use quantization::scalar::OperatorScalarQuantization; -use rabitq::operator::OperatorRabitq; pub trait OperatorIndexing where Self: Operator, Self: OperatorIvf, Self: OperatorInvertedIndex, - Self: OperatorRabitq, Self: OperatorScalarQuantization, Self: OperatorProductQuantization, + Self: OperatorRabitqQuantization, { } @@ -25,8 +25,8 @@ where Self: Operator, Self: OperatorIvf, Self: OperatorInvertedIndex, - Self: OperatorRabitq, Self: OperatorScalarQuantization, Self: OperatorProductQuantization, + Self: OperatorRabitqQuantization, { } diff --git a/crates/indexing/src/sealed.rs b/crates/indexing/src/sealed.rs index 23fda96c2..a01729eb1 100644 --- a/crates/indexing/src/sealed.rs +++ b/crates/indexing/src/sealed.rs @@ -7,23 +7,26 @@ use hnsw::Hnsw; use inverted::InvertedIndex; use ivf::Ivf; use quantization::product::ProductQuantizer; +use quantization::rabitq::RabitqQuantizer; use quantization::scalar::ScalarQuantizer; use quantization::trivial::TrivialQuantizer; -use rabitq::Rabitq; +use std::any::Any; use std::path::Path; pub enum SealedIndexing { Flat(Flat>), FlatSq(Flat>), FlatPq(Flat>), + FlatRq(Flat>), Ivf(Ivf>), IvfSq(Ivf>), IvfPq(Ivf>), + IvfRq(Ivf>), Hnsw(Hnsw>), HnswSq(Hnsw>), HnswPq(Hnsw>), + HnswRq(Hnsw>), InvertedIndex(InvertedIndex), - Rabitq(Rabitq), } impl SealedIndexing { @@ -43,6 +46,9 @@ impl SealedIndexing { Some(QuantizationOptions::Product(_)) => { Self::FlatPq(Flat::create(path, options, source)) } + Some(QuantizationOptions::Rabitq(_)) => { + Self::FlatRq(Flat::create(path, options, source)) + } }, IndexingOptions::Ivf(IvfIndexingOptions { ref quantization, .. @@ -54,6 +60,9 @@ impl SealedIndexing { Some(QuantizationOptions::Product(_)) => { Self::IvfPq(Ivf::create(path, options, source)) } + Some(QuantizationOptions::Rabitq(_)) => { + Self::IvfRq(Ivf::create(path, options, source)) + } }, IndexingOptions::Hnsw(HnswIndexingOptions { ref quantization, .. @@ -65,11 +74,13 @@ impl SealedIndexing { Some(QuantizationOptions::Product(_)) => { Self::HnswPq(Hnsw::create(path, options, source)) } + Some(QuantizationOptions::Rabitq(_)) => { + Self::HnswRq(Hnsw::create(path, options, source)) + } }, IndexingOptions::InvertedIndex(_) => { Self::InvertedIndex(InvertedIndex::create(path, options, source)) } - IndexingOptions::Rabitq(_) => Self::Rabitq(Rabitq::create(path, options, source)), } } @@ -81,6 +92,7 @@ impl SealedIndexing { None => Self::Flat(Flat::open(path)), Some(QuantizationOptions::Scalar(_)) => Self::FlatSq(Flat::open(path)), Some(QuantizationOptions::Product(_)) => Self::FlatPq(Flat::open(path)), + Some(QuantizationOptions::Rabitq(_)) => Self::FlatRq(Flat::open(path)), }, IndexingOptions::Ivf(IvfIndexingOptions { ref quantization, .. @@ -88,6 +100,7 @@ impl SealedIndexing { None => Self::Ivf(Ivf::open(path)), Some(QuantizationOptions::Scalar(_)) => Self::IvfSq(Ivf::open(path)), Some(QuantizationOptions::Product(_)) => Self::IvfPq(Ivf::open(path)), + Some(QuantizationOptions::Rabitq(_)) => Self::IvfRq(Ivf::open(path)), }, IndexingOptions::Hnsw(HnswIndexingOptions { ref quantization, .. @@ -95,9 +108,9 @@ impl SealedIndexing { None => Self::Hnsw(Hnsw::open(path)), Some(QuantizationOptions::Scalar(_)) => Self::HnswSq(Hnsw::open(path)), Some(QuantizationOptions::Product(_)) => Self::HnswPq(Hnsw::open(path)), + Some(QuantizationOptions::Rabitq(_)) => Self::HnswRq(Hnsw::open(path)), }, IndexingOptions::InvertedIndex(_) => Self::InvertedIndex(InvertedIndex::open(path)), - IndexingOptions::Rabitq(_) => Self::Rabitq(Rabitq::open(path)), } } @@ -110,14 +123,34 @@ impl SealedIndexing { SealedIndexing::Flat(x) => x.vbase(vector, opts), SealedIndexing::FlatPq(x) => x.vbase(vector, opts), SealedIndexing::FlatSq(x) => x.vbase(vector, opts), + SealedIndexing::FlatRq(x) => x.vbase(vector, opts), SealedIndexing::Ivf(x) => x.vbase(vector, opts), SealedIndexing::IvfPq(x) => x.vbase(vector, opts), SealedIndexing::IvfSq(x) => x.vbase(vector, opts), + SealedIndexing::IvfRq(x) => x.vbase(vector, opts), SealedIndexing::Hnsw(x) => x.vbase(vector, opts), SealedIndexing::HnswPq(x) => x.vbase(vector, opts), SealedIndexing::HnswSq(x) => x.vbase(vector, opts), + SealedIndexing::HnswRq(x) => x.vbase(vector, opts), SealedIndexing::InvertedIndex(x) => x.vbase(vector, opts), - SealedIndexing::Rabitq(x) => x.vbase(vector, opts), + } + } + + pub fn as_any(&self) -> &dyn Any { + match &self { + SealedIndexing::Flat(x) => x, + SealedIndexing::FlatPq(x) => x, + SealedIndexing::FlatSq(x) => x, + SealedIndexing::FlatRq(x) => x, + SealedIndexing::Ivf(x) => x, + SealedIndexing::IvfPq(x) => x, + SealedIndexing::IvfSq(x) => x, + SealedIndexing::IvfRq(x) => x, + SealedIndexing::Hnsw(x) => x, + SealedIndexing::HnswPq(x) => x, + SealedIndexing::HnswSq(x) => x, + SealedIndexing::HnswRq(x) => x, + SealedIndexing::InvertedIndex(x) => x, } } } @@ -128,14 +161,16 @@ impl Vectors> for SealedIndexing { SealedIndexing::Flat(x) => x.dims(), SealedIndexing::FlatSq(x) => x.dims(), SealedIndexing::FlatPq(x) => x.dims(), + SealedIndexing::FlatRq(x) => x.dims(), SealedIndexing::Ivf(x) => x.dims(), SealedIndexing::IvfSq(x) => x.dims(), SealedIndexing::IvfPq(x) => x.dims(), + SealedIndexing::IvfRq(x) => x.dims(), SealedIndexing::Hnsw(x) => x.dims(), SealedIndexing::HnswPq(x) => x.dims(), SealedIndexing::HnswSq(x) => x.dims(), + SealedIndexing::HnswRq(x) => x.dims(), SealedIndexing::InvertedIndex(x) => x.dims(), - SealedIndexing::Rabitq(x) => x.dims(), } } @@ -144,14 +179,16 @@ impl Vectors> for SealedIndexing { SealedIndexing::Flat(x) => x.len(), SealedIndexing::FlatPq(x) => x.len(), SealedIndexing::FlatSq(x) => x.len(), + SealedIndexing::FlatRq(x) => x.len(), SealedIndexing::Ivf(x) => x.len(), SealedIndexing::IvfPq(x) => x.len(), SealedIndexing::IvfSq(x) => x.len(), + SealedIndexing::IvfRq(x) => x.len(), SealedIndexing::Hnsw(x) => x.len(), SealedIndexing::HnswPq(x) => x.len(), SealedIndexing::HnswSq(x) => x.len(), + SealedIndexing::HnswRq(x) => x.len(), SealedIndexing::InvertedIndex(x) => x.len(), - SealedIndexing::Rabitq(x) => x.len(), } } @@ -160,14 +197,16 @@ impl Vectors> for SealedIndexing { SealedIndexing::Flat(x) => x.vector(i), SealedIndexing::FlatPq(x) => x.vector(i), SealedIndexing::FlatSq(x) => x.vector(i), + SealedIndexing::FlatRq(x) => x.vector(i), SealedIndexing::Ivf(x) => x.vector(i), SealedIndexing::IvfSq(x) => x.vector(i), SealedIndexing::IvfPq(x) => x.vector(i), + SealedIndexing::IvfRq(x) => x.vector(i), SealedIndexing::Hnsw(x) => x.vector(i), SealedIndexing::HnswSq(x) => x.vector(i), SealedIndexing::HnswPq(x) => x.vector(i), + SealedIndexing::HnswRq(x) => x.vector(i), SealedIndexing::InvertedIndex(x) => x.vector(i), - SealedIndexing::Rabitq(x) => x.vector(i), } } } @@ -178,14 +217,16 @@ impl Collection for SealedIndexing { SealedIndexing::Flat(x) => x.payload(i), SealedIndexing::FlatPq(x) => x.payload(i), SealedIndexing::FlatSq(x) => x.payload(i), + SealedIndexing::FlatRq(x) => x.payload(i), SealedIndexing::Ivf(x) => x.payload(i), SealedIndexing::IvfPq(x) => x.payload(i), SealedIndexing::IvfSq(x) => x.payload(i), + SealedIndexing::IvfRq(x) => x.payload(i), SealedIndexing::Hnsw(x) => x.payload(i), SealedIndexing::HnswPq(x) => x.payload(i), SealedIndexing::HnswSq(x) => x.payload(i), + SealedIndexing::HnswRq(x) => x.payload(i), SealedIndexing::InvertedIndex(x) => x.payload(i), - SealedIndexing::Rabitq(x) => x.payload(i), } } } diff --git a/crates/ivf/src/lib.rs b/crates/ivf/src/lib.rs index 024bd7e2c..a52104f20 100644 --- a/crates/ivf/src/lib.rs +++ b/crates/ivf/src/lib.rs @@ -31,7 +31,7 @@ pub struct Ivf> { quantization: Quantization, payloads: MmapArray, offsets: Json>, - centroids: Json::Scalar>>, + projected_centroids: Json::Scalar>>, is_residual: Json, } @@ -70,23 +70,37 @@ impl> Ivf { vector: Borrowed<'a, O>, opts: &'a SearchOptions, ) -> Box + 'a> { + let projected_vector = self.quantization.project(vector); let lists = select( - k_means_lookup_many(O::interpret(vector), &self.centroids), + k_means_lookup_many( + O::interpret(projected_vector.as_borrowed()), + &self.projected_centroids, + ), opts.ivf_nprobe as usize, ); let mut heap = Q::flat_rerank_start(); - let mut lut = self.quantization.flat_rerank_preprocess(vector, opts); + let lut = if *self.is_residual { + None + } else { + Some(self.quantization.flat_rerank_preprocess(vector, opts)) + }; for i in lists.iter().map(|(_, i)| *i) { - if *self.is_residual { - let vector = O::residual(vector, &self.centroids[(i,)]); - lut = self - .quantization - .flat_rerank_preprocess(vector.as_borrowed(), opts); - } + let lut = if let Some(lut) = lut.as_ref() { + lut + } else { + &self.quantization.flat_rerank_preprocess( + O::residual( + projected_vector.as_borrowed(), + &self.projected_centroids[(i,)], + ) + .as_borrowed(), + opts, + ) + }; let start = self.offsets[i]; let end = self.offsets[i + 1]; self.quantization - .flat_rerank_continue(&lut, start..end, &mut heap); + .flat_rerank_continue(lut, start..end, &mut heap); } let mut reranker = self.quantization.flat_rerank_break( heap, @@ -116,7 +130,7 @@ fn from_nothing>( } = options.indexing.clone().unwrap_ivf(); let samples = O::sample(collection, nlist); rayon::check(); - let centroids = k_means(nlist as usize, samples, true, spherical_centroids, false); + let centroids = k_means(nlist as usize, samples, spherical_centroids, 10, false); rayon::check(); let fa = (0..collection.len()) .into_par_iter() @@ -174,14 +188,21 @@ fn from_nothing>( (0..collection.len()).map(|i| collection.payload(i)), ); let offsets = Json::create(path.as_ref().join("offsets"), offsets); - let centroids = Json::create(path.as_ref().join("centroids"), centroids); + let projected_centroids = Json::create(path.as_ref().join("projected_centroids"), { + let mut projected_centroids = Vec2::zeros(centroids.shape()); + for i in 0..centroids.shape_0() { + projected_centroids[(i,)] + .copy_from_slice(&O::project(quantization.quantizer(), ¢roids[(i,)])); + } + projected_centroids + }); let is_residual = Json::create(path.as_ref().join("is_residual"), is_residual); Ivf { storage, quantization, payloads, offsets, - centroids, + projected_centroids, is_residual, } } @@ -191,14 +212,14 @@ fn open>(path: impl AsRef) -> Ivf { let quantization = Quantization::open(path.as_ref().join("quantization")); let payloads = MmapArray::open(path.as_ref().join("payloads")); let offsets = Json::open(path.as_ref().join("offsets")); - let centroids = Json::open(path.as_ref().join("centroids")); + let projected_centroids = Json::open(path.as_ref().join("projected_centroids")); let is_residual = Json::open(path.as_ref().join("is_residual")); Ivf { storage, quantization, payloads, offsets, - centroids, + projected_centroids, is_residual, } } diff --git a/crates/ivf/src/operator.rs b/crates/ivf/src/operator.rs index f99f4c0ef..7e2706b63 100644 --- a/crates/ivf/src/operator.rs +++ b/crates/ivf/src/operator.rs @@ -4,31 +4,33 @@ use base::scalar::ScalarLike; use base::search::Vectors; use base::vector::*; use common::vec2::Vec2; +use quantization::quantizer::Quantizer; use storage::OperatorStorage; pub trait OperatorIvf: OperatorStorage { const SUPPORT: bool; type Scalar: ScalarLike; - fn sample( - vectors: &impl Vectors, - nlist: u32, - ) -> Vec2<::Scalar>; - fn interpret(vector: Borrowed<'_, Self>) -> &[::Scalar]; + fn sample(vectors: &impl Vectors, nlist: u32) -> Vec2; + fn interpret(vector: Borrowed<'_, Self>) -> &[Self::Scalar]; + fn project>(quantizer: &Q, slice: &[Self::Scalar]) -> Vec; const SUPPORT_RESIDUAL: bool; - fn residual(lhs: Borrowed<'_, Self>, rhs: &[::Scalar]) -> Owned; + fn residual(lhs: Borrowed<'_, Self>, rhs: &[Self::Scalar]) -> Owned; } impl OperatorIvf for BVectorDot { const SUPPORT: bool = false; type Scalar = Impossible; - fn sample(_: &impl Vectors, _: u32) -> Vec2<::Scalar> { + fn sample(_: &impl Vectors, _: u32) -> Vec2 { unimplemented!() } - fn interpret(_: Borrowed<'_, Self>) -> &[::Scalar] { + fn interpret(_: Borrowed<'_, Self>) -> &[Self::Scalar] { + unimplemented!() + } + fn project>(_: &Q, _: &[Self::Scalar]) -> Vec { unimplemented!() } const SUPPORT_RESIDUAL: bool = false; - fn residual(_lhs: Borrowed<'_, Self>, _rhs: &[::Scalar]) -> Owned { + fn residual(_lhs: Borrowed<'_, Self>, _rhs: &[Self::Scalar]) -> Owned { unimplemented!() } } @@ -36,14 +38,17 @@ impl OperatorIvf for BVectorDot { impl OperatorIvf for BVectorJaccard { const SUPPORT: bool = false; type Scalar = Impossible; - fn sample(_: &impl Vectors, _: u32) -> Vec2<::Scalar> { + fn sample(_: &impl Vectors, _: u32) -> Vec2 { + unimplemented!() + } + fn interpret(_: Borrowed<'_, Self>) -> &[Self::Scalar] { unimplemented!() } - fn interpret(_: Borrowed<'_, Self>) -> &[::Scalar] { + fn project>(_: &Q, _: &[Self::Scalar]) -> Vec { unimplemented!() } const SUPPORT_RESIDUAL: bool = false; - fn residual(_lhs: Borrowed<'_, Self>, _rhs: &[::Scalar]) -> Owned { + fn residual(_lhs: Borrowed<'_, Self>, _rhs: &[Self::Scalar]) -> Owned { unimplemented!() } } @@ -51,14 +56,17 @@ impl OperatorIvf for BVectorJaccard { impl OperatorIvf for BVectorHamming { const SUPPORT: bool = false; type Scalar = Impossible; - fn sample(_: &impl Vectors, _: u32) -> Vec2<::Scalar> { + fn sample(_: &impl Vectors, _: u32) -> Vec2 { + unimplemented!() + } + fn interpret(_: Borrowed<'_, Self>) -> &[Self::Scalar] { unimplemented!() } - fn interpret(_: Borrowed<'_, Self>) -> &[::Scalar] { + fn project>(_: &Q, _: &[Self::Scalar]) -> Vec { unimplemented!() } const SUPPORT_RESIDUAL: bool = false; - fn residual(_lhs: Borrowed<'_, Self>, _rhs: &[::Scalar]) -> Owned { + fn residual(_lhs: Borrowed<'_, Self>, _rhs: &[Self::Scalar]) -> Owned { unimplemented!() } } @@ -66,14 +74,17 @@ impl OperatorIvf for BVectorHamming { impl OperatorIvf for SVectDot { const SUPPORT: bool = false; type Scalar = Impossible; - fn sample(_: &impl Vectors, _: u32) -> Vec2<::Scalar> { + fn sample(_: &impl Vectors, _: u32) -> Vec2 { unimplemented!() } - fn interpret(_: Borrowed<'_, Self>) -> &[::Scalar] { + fn interpret(_: Borrowed<'_, Self>) -> &[Self::Scalar] { + unimplemented!() + } + fn project>(_: &Q, _: &[Self::Scalar]) -> Vec { unimplemented!() } const SUPPORT_RESIDUAL: bool = false; - fn residual(_lhs: Borrowed<'_, Self>, _rhs: &[::Scalar]) -> Owned { + fn residual(_lhs: Borrowed<'_, Self>, _rhs: &[Self::Scalar]) -> Owned { unimplemented!() } } @@ -81,14 +92,17 @@ impl OperatorIvf for SVectDot { impl OperatorIvf for SVectL2 { const SUPPORT: bool = false; type Scalar = Impossible; - fn sample(_: &impl Vectors, _: u32) -> Vec2<::Scalar> { + fn sample(_: &impl Vectors, _: u32) -> Vec2 { + unimplemented!() + } + fn interpret(_: Borrowed<'_, Self>) -> &[Self::Scalar] { unimplemented!() } - fn interpret(_: Borrowed<'_, Self>) -> &[::Scalar] { + fn project>(_: &Q, _: &[Self::Scalar]) -> Vec { unimplemented!() } const SUPPORT_RESIDUAL: bool = false; - fn residual(_lhs: Borrowed<'_, Self>, _rhs: &[::Scalar]) -> Owned { + fn residual(_lhs: Borrowed<'_, Self>, _rhs: &[Self::Scalar]) -> Owned { unimplemented!() } } @@ -96,10 +110,7 @@ impl OperatorIvf for SVectL2 { impl OperatorIvf for VectDot { const SUPPORT: bool = true; type Scalar = S; - fn sample( - vectors: &impl Vectors, - nlist: u32, - ) -> Vec2<::Scalar> { + fn sample(vectors: &impl Vectors, nlist: u32) -> Vec2 { common::sample::sample( vectors.len(), nlist.saturating_mul(256).min(1 << 20), @@ -107,9 +118,12 @@ impl OperatorIvf for VectDot { |i| vectors.vector(i).slice(), ) } - fn interpret(x: Borrowed<'_, Self>) -> &[::Scalar] { + fn interpret(x: Borrowed<'_, Self>) -> &[Self::Scalar] { x.slice() } + fn project>(quantizer: &Q, centroid: &[Self::Scalar]) -> Vec { + quantizer.project(VectBorrowed::new(centroid)).into_vec() + } const SUPPORT_RESIDUAL: bool = false; fn residual(_lhs: Borrowed<'_, Self>, _rhs: &[S]) -> Owned { unimplemented!() @@ -119,10 +133,7 @@ impl OperatorIvf for VectDot { impl OperatorIvf for VectL2 { const SUPPORT: bool = true; type Scalar = S; - fn sample( - vectors: &impl Vectors, - nlist: u32, - ) -> Vec2<::Scalar> { + fn sample(vectors: &impl Vectors, nlist: u32) -> Vec2 { common::sample::sample( vectors.len(), nlist.saturating_mul(256).min(1 << 20), @@ -130,9 +141,12 @@ impl OperatorIvf for VectL2 { |i| vectors.vector(i).slice(), ) } - fn interpret(x: Borrowed<'_, Self>) -> &[::Scalar] { + fn interpret(x: Borrowed<'_, Self>) -> &[Self::Scalar] { x.slice() } + fn project>(quantizer: &Q, vector: &[Self::Scalar]) -> Vec { + quantizer.project(VectBorrowed::new(vector)).into_vec() + } const SUPPORT_RESIDUAL: bool = true; fn residual(lhs: Borrowed<'_, Self>, rhs: &[S]) -> Owned { lhs.operator_sub(VectBorrowed::new(rhs)) diff --git a/crates/k_means/src/elkan.rs b/crates/k_means/src/elkan.rs deleted file mode 100644 index 0aaa9e747..000000000 --- a/crates/k_means/src/elkan.rs +++ /dev/null @@ -1,264 +0,0 @@ -use base::scalar::*; -use common::vec2::Vec2; -use half::f16; -use rand::rngs::StdRng; -use rand::{Rng, SeedableRng}; -use std::ops::{Index, IndexMut}; - -pub struct ElkanKMeans { - dims: usize, - c: usize, - is_spherical: bool, - centroids: Vec2, - lowerbound: Square, - upperbound: Vec, - assign: Vec, - rng: StdRng, - samples: Vec2, - first: bool, -} - -const DELTA: f32 = f16::EPSILON.to_f32_const(); - -impl ElkanKMeans { - pub fn new(c: usize, samples: Vec2, is_spherical: bool) -> Self { - let n = samples.shape_0(); - let dims = samples.shape_1(); - - let mut rng = StdRng::from_entropy(); - let mut centroids = Vec2::zeros((c, dims)); - let mut lowerbound = Square::new(n, c); - let mut upperbound = vec![0.0f32; n]; - let mut assign = vec![0usize; n]; - - centroids[(0,)].copy_from_slice(&samples[(rng.gen_range(0..n),)]); - - let mut weight = vec![f32::INFINITY; n]; - let mut dis = vec![0.0f32; n]; - for i in 0..c { - let mut sum = 0.0f32; - for j in 0..n { - dis[j] = S::reduce_sum_of_d2(&samples[(j,)], ¢roids[(i,)]).sqrt(); - } - for j in 0..n { - lowerbound[(j, i)] = dis[j]; - if dis[j] * dis[j] < weight[j] { - weight[j] = dis[j] * dis[j]; - } - sum += weight[j]; - } - if i + 1 == c { - break; - } - let index = 'a: { - let mut choice = sum * rng.gen_range(0.0..1.0); - for j in 0..(n - 1) { - choice -= weight[j]; - if choice < 0.0f32 { - break 'a j; - } - } - n - 1 - }; - centroids[(i + 1,)].copy_from_slice(&samples[(index,)]); - } - - for i in 0..n { - let mut minimal = f32::INFINITY; - let mut target = 0; - for j in 0..c { - let dis = lowerbound[(i, j)]; - if dis < minimal { - minimal = dis; - target = j; - } - } - assign[i] = target; - upperbound[i] = minimal; - } - - Self { - dims, - c, - is_spherical, - centroids, - lowerbound, - upperbound, - assign, - rng, - samples, - first: true, - } - } - - pub fn iterate(&mut self) -> bool { - let c = self.c; - let dims = self.dims; - let samples = &self.samples; - let rand = &mut self.rng; - let assign = &mut self.assign; - let centroids = &mut self.centroids; - let lowerbound = &mut self.lowerbound; - let upperbound = &mut self.upperbound; - let mut change = 0; - let n = samples.shape_0(); - // Step 1 - let mut dist0 = Square::new(c, c); - let mut sp = vec![0.0f32; c]; - for i in 0..c { - for j in 0..c { - dist0[(i, j)] = - S::reduce_sum_of_d2(¢roids[(i,)], ¢roids[(j,)]).sqrt() * 0.5; - } - } - for i in 0..c { - let mut minimal = f32::INFINITY; - for j in 0..c { - if i == j { - continue; - } - let dis = dist0[(i, j)]; - if dis < minimal { - minimal = dis; - } - } - sp[i] = minimal; - } - let mut dis = vec![0.0f32; n]; - for i in 0..n { - if upperbound[i] > sp[assign[i]] { - dis[i] = S::reduce_sum_of_d2(&samples[(i,)], ¢roids[(assign[i],)]).sqrt(); - } - } - for i in 0..n { - // Step 2 - if upperbound[i] <= sp[assign[i]] { - continue; - } - let mut minimal = dis[i]; - lowerbound[(i, assign[i])] = minimal; - upperbound[i] = minimal; - // Step 3 - for j in 0..c { - if j == assign[i] { - continue; - } - if upperbound[i] <= lowerbound[(i, j)] { - continue; - } - if upperbound[i] <= dist0[(assign[i], j)] { - continue; - } - if minimal > lowerbound[(i, j)] || minimal > dist0[(assign[i], j)] { - let dis = S::reduce_sum_of_d2(&samples[(i,)], ¢roids[(j,)]).sqrt(); - lowerbound[(i, j)] = dis; - if dis < minimal { - minimal = dis; - assign[i] = j; - upperbound[i] = dis; - change += 1; - } - } - } - } - - // Step 4, 7 - let old_centroids = std::mem::replace(centroids, Vec2::zeros((c, dims))); - let mut count = vec![0.0f32; c]; - for i in 0..n { - S::vector_add_inplace(&mut centroids[(self.assign[i],)], &samples[(i,)]); - count[self.assign[i]] += 1.0; - } - for i in 0..c { - if count[i] == 0.0f32 { - continue; - } - S::vector_mul_scalar_inplace(&mut centroids[(i,)], 1.0 / count[i]); - } - for i in 0..c { - if count[i] != 0.0f32 { - continue; - } - let mut o = 0; - loop { - let alpha = f32::from_f32(rand.gen_range(0.0..1.0f32)); - let beta = (count[o] - 1.0) / (n - c) as f32; - if alpha < beta { - break; - } - o = (o + 1) % c; - } - centroids.copy_within((o,), (i,)); - S::kmeans_helper(&mut centroids[(i,)], 1.0 + DELTA, 1.0 - DELTA); - S::kmeans_helper(&mut centroids[(o,)], 1.0 - DELTA, 1.0 + DELTA); - count[i] = count[o] / 2.0; - count[o] -= count[i]; - } - - if self.is_spherical { - for i in 0..c { - let centroid = &mut centroids[(i,)]; - let l = S::reduce_sum_of_x2(centroid).sqrt(); - S::vector_mul_scalar_inplace(centroid, 1.0 / l); - } - } - - // Step 5, 6 - let mut dist1 = vec![0.0f32; c]; - for i in 0..c { - dist1[i] = S::reduce_sum_of_d2(&old_centroids[(i,)], ¢roids[(i,)]).sqrt(); - } - for i in 0..n { - for j in 0..c { - self.lowerbound[(i, j)] = 0.0f32.max(self.lowerbound[(i, j)] - dist1[j]); - } - } - for i in 0..n { - self.upperbound[i] += dist1[self.assign[i]]; - } - if self.first { - self.first = false; - false - } else { - change == 0 - } - } - - pub fn finish(self) -> Vec2 { - self.centroids - } -} - -struct Square { - x: usize, - y: usize, - v: Vec, -} - -impl Square { - pub fn new(x: usize, y: usize) -> Self { - Self { - x, - y, - v: base::pod::zeroed_vec(x * y), - } - } -} - -impl Index<(usize, usize)> for Square { - type Output = f32; - - fn index(&self, (x, y): (usize, usize)) -> &Self::Output { - debug_assert!(x < self.x); - debug_assert!(y < self.y); - &self.v[x * self.y + y] - } -} - -impl IndexMut<(usize, usize)> for Square { - fn index_mut(&mut self, (x, y): (usize, usize)) -> &mut Self::Output { - debug_assert!(x < self.x); - debug_assert!(y < self.y); - &mut self.v[x * self.y + y] - } -} diff --git a/crates/k_means/src/lib.rs b/crates/k_means/src/lib.rs index b57d57678..702163870 100644 --- a/crates/k_means/src/lib.rs +++ b/crates/k_means/src/lib.rs @@ -1,13 +1,11 @@ #![allow(clippy::needless_range_loop)] -pub mod elkan; pub mod kmeans1d; pub mod lloyd; pub mod quick_centers; use base::scalar::*; use common::vec2::Vec2; -use elkan::ElkanKMeans; use kmeans1d::kmeans1d; use lloyd::LloydKMeans; use stoppable_rayon as rayon; @@ -15,8 +13,8 @@ use stoppable_rayon as rayon; pub fn k_means( c: usize, mut samples: Vec2, - prefer_multithreading: bool, is_spherical: bool, + iterations: usize, prefer_kmeanspp: bool, ) -> Vec2 { assert!(c > 0); @@ -38,25 +36,14 @@ pub fn k_means( let centroids = S::vector_from_f32(&kmeans1d(c, samples.as_slice())); return Vec2::from_vec((c, 1), centroids); } - if prefer_multithreading { - let mut lloyd_k_means = LloydKMeans::new(c, samples, is_spherical, prefer_kmeanspp); - for _ in 0..25 { - rayon::check(); - if lloyd_k_means.iterate() { - break; - } + let mut lloyd_k_means = LloydKMeans::new(c, samples, is_spherical, prefer_kmeanspp); + for _ in 0..iterations { + rayon::check(); + if lloyd_k_means.iterate() { + break; } - lloyd_k_means.finish() - } else { - let mut elkan_k_means = ElkanKMeans::new(c, samples, is_spherical); - for _ in 0..100 { - rayon::check(); - if elkan_k_means.iterate() { - break; - } - } - elkan_k_means.finish() } + lloyd_k_means.finish() } pub fn k_means_lookup(vector: &[S], centroids: &Vec2) -> usize { diff --git a/crates/quantization/Cargo.toml b/crates/quantization/Cargo.toml index 75852c219..99d3fb3e6 100644 --- a/crates/quantization/Cargo.toml +++ b/crates/quantization/Cargo.toml @@ -5,7 +5,10 @@ edition.workspace = true [dependencies] log.workspace = true +nalgebra = "0.33.0" rand.workspace = true +rand_chacha = "0.3.1" +rand_distr.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/crates/quantization/src/lib.rs b/crates/quantization/src/lib.rs index fe36a952e..412fd4753 100644 --- a/crates/quantization/src/lib.rs +++ b/crates/quantization/src/lib.rs @@ -9,6 +9,7 @@ pub mod fast_scan; pub mod product; pub mod quantize; pub mod quantizer; +pub mod rabitq; pub mod reranker; pub mod scalar; pub mod trivial; @@ -22,14 +23,17 @@ use base::vector::VectorOwned; use common::json::Json; use common::mmap_array::MmapArray; use quantizer::Quantizer; +use rayon::iter::IntoParallelIterator; +use rayon::iter::ParallelIterator; use std::marker::PhantomData; use std::ops::Range; use std::path::Path; +use stoppable_rayon as rayon; pub struct Quantization { - train: Json, + quantizer: Json, codes: MmapArray, - packed_codes: MmapArray, + fcodes: MmapArray, _maker: PhantomData O>, } @@ -42,53 +46,73 @@ impl> Quantization { transform: impl Fn(Borrowed<'_, O>) -> Owned + Copy + Send + Sync, ) -> Self { std::fs::create_dir(path.as_ref()).unwrap(); - let train = Q::train(vector_options, quantization_options, vectors, transform); - let train = Json::create(path.as_ref().join("train"), train); + let quantizer = Json::create( + path.as_ref().join("quantizer"), + Q::train(vector_options, quantization_options, vectors, transform), + ); let codes = MmapArray::create(path.as_ref().join("codes"), { - (0..vectors.len()).flat_map(|i| { - let vector = transform(vectors.vector(i)); - train.encode(vector.as_borrowed()) - }) + (0..vectors.len()) + .into_par_iter() + .map(|i| { + let vector = quantizer.project(transform(vectors.vector(i)).as_borrowed()); + quantizer.encode(vector.as_borrowed()) + }) + .collect::>() + .into_iter() + .flatten() }); - let packed_codes = MmapArray::create(path.as_ref().join("packed_codes"), { + let fcodes = MmapArray::create(path.as_ref().join("fcodes"), { let d = vectors.dims(); let n = vectors.len(); let m = n.div_ceil(32); - let train = &train; - (0..m).flat_map(move |alpha| { - let vectors = std::array::from_fn(|beta| { - let i = 32 * alpha + beta as u32; - if i < n { - transform(vectors.vector(i)) - } else { - O::Vector::zero(d) - } - }); - train.fscan_encode(vectors) - }) + let train = &quantizer; + (0..m) + .into_par_iter() + .map(move |alpha| { + let vectors = std::array::from_fn(|beta| { + let i = 32 * alpha + beta as u32; + if i < n { + train.project(transform(vectors.vector(i)).as_borrowed()) + } else { + O::Vector::zero(d) + } + }); + train.fscan_encode(vectors) + }) + .collect::>() + .into_iter() + .flatten() }); Self { - train, + quantizer, codes, - packed_codes, + fcodes, _maker: PhantomData, } } pub fn open(path: impl AsRef) -> Self { - let train = Json::open(path.as_ref().join("train")); + let quantizer = Json::open(path.as_ref().join("quantizer")); let codes = MmapArray::open(path.as_ref().join("codes")); - let packed_codes = MmapArray::open(path.as_ref().join("packed_codes")); + let fcodes = MmapArray::open(path.as_ref().join("fcodes")); Self { - train, + quantizer, codes, - packed_codes, + fcodes, _maker: PhantomData, } } + pub fn quantizer(&self) -> &Q { + &self.quantizer + } + + pub fn project(&self, vector: Borrowed<'_, O>) -> Owned { + Q::project(&self.quantizer, vector) + } + pub fn preprocess(&self, vector: Borrowed<'_, O>) -> Q::Lut { - Q::preprocess(&self.train, vector) + Q::preprocess(&self.quantizer, vector) } pub fn flat_rerank_preprocess( @@ -96,18 +120,18 @@ impl> Quantization { vector: Borrowed<'_, O>, opts: &SearchOptions, ) -> Result { - Q::flat_rerank_preprocess(&self.train, vector, opts) + Q::flat_rerank_preprocess(&self.quantizer, vector, opts) } pub fn process(&self, vectors: &impl Vectors>, lut: &Q::Lut, u: u32) -> Distance { let locate = |i| { - let code_size = self.train.code_size() as usize; + let code_size = self.quantizer.code_size() as usize; let start = i as usize * code_size; let end = start + code_size; &self.codes[start..end] }; let vector = vectors.vector(u); - Q::process(&self.train, lut, locate(u), vector) + Q::process(&self.quantizer, lut, locate(u), vector) } pub fn flat_rerank_continue( @@ -117,18 +141,18 @@ impl> Quantization { heap: &mut Q::FlatRerankVec, ) { Q::flat_rerank_continue( - &self.train, + &self.quantizer, |i| { - let code_size = self.train.code_size() as usize; + let code_size = self.quantizer.code_size() as usize; let start = i as usize * code_size; let end = start + code_size; &self.codes[start..end] }, |i| { - let fcode_size = self.train.fcode_size() as usize; + let fcode_size = self.quantizer.fcode_size() as usize; let start = i as usize * fcode_size; let end = start + fcode_size; - &self.packed_codes[start..end] + &self.fcodes[start..end] }, frlut, range, @@ -145,23 +169,23 @@ impl> Quantization { where R: Fn(u32) -> (Distance, T) + 'a, { - Q::flat_rerank_break(&self.train, heap, rerank, opts) + Q::flat_rerank_break(&self.quantizer, heap, rerank, opts) } pub fn graph_rerank<'a, T: 'a, R: Fn(u32) -> (Distance, T) + 'a>( &'a self, - vector: Borrowed<'a, O>, + lut: Q::Lut, rerank: R, ) -> impl RerankerPush + RerankerPop + 'a { Q::graph_rerank( - &self.train, + &self.quantizer, + lut, |i| { - let code_size = self.train.code_size() as usize; + let code_size = self.quantizer.code_size() as usize; let start = i as usize * code_size; let end = start + code_size; &self.codes[start..end] }, - vector, rerank, ) } diff --git a/crates/quantization/src/product.rs b/crates/quantization/src/product.rs index 76c1ddadb..e6efbb3cc 100644 --- a/crates/quantization/src/product.rs +++ b/crates/quantization/src/product.rs @@ -15,6 +15,7 @@ use base::operator::*; use base::scalar::impossible::Impossible; use base::scalar::ScalarLike; use base::search::*; +use base::vector::VectorBorrowed; use base::vector::VectorOwned; use common::sample::sample; use common::vec2::Vec2; @@ -66,7 +67,7 @@ impl Quantizer for ProductQuantizer { ) .to_vec() }); - k_means(1 << bits, subsamples, false, false, true) + k_means(1 << bits, subsamples, false, 25, true) }) .collect::>(); let mut centroids = Vec2::zeros((1 << bits, dims as usize)); @@ -151,6 +152,10 @@ impl Quantizer for ProductQuantizer { } } + fn project(&self, vector: Borrowed<'_, O>) -> Owned { + vector.own() + } + type Lut = Vec; fn preprocess(&self, vector: Borrowed<'_, O>) -> Self::Lut { @@ -184,7 +189,7 @@ impl Quantizer for ProductQuantizer { ) } - fn fscan_process(flut: &Self::FLut, code: &[u8]) -> [Distance; 32] { + fn fscan_process(&self, flut: &Self::FLut, code: &[u8]) -> [Distance; 32] { O::fscan_process(flut, code) } @@ -219,19 +224,27 @@ impl Quantizer for ProductQuantizer { match frlut { Ok(flut) => { fn divide(r: Range) -> (Option, Range, Option) { - if r.start > r.end || r.start % 32 == 0 && r.end % 32 == 0 { - (None, r.start / 32..r.end / 32, None) - } else if r.start / 32 == r.end / 32 { - (Some(r.start / 32), 0..0, None) - } else { - let left = (r.start % 32 != 0).then_some(r.start / 32); - let right = (r.end % 32 != 0).then_some(r.end / 32); - (left, r.start / 32 + 1..r.end / 32, right) + if r.start > r.end { + return (None, r.start / 32..r.end / 32, None); } + if r.start / 32 == r.end / 32 { + return (Some(r.start / 32), 0..0, None); + }; + let left = if r.start % 32 == 0 { + (None, r.start / 32) + } else { + (Some(r.start / 32), r.start / 32 + 1) + }; + let right = if r.end % 32 == 0 { + (r.end / 32, None) + } else { + (r.end / 32, Some(r.end / 32)) + }; + (left.0, left.1..right.0, right.1) } let (left, main, right) = divide(range.clone()); if let Some(i) = left { - let r = Self::fscan_process(flut, locate_1(i).as_ref()); + let r = self.fscan_process(flut, locate_1(i).as_ref()); for j in 0..32 { if range.contains(&(i * 32 + j)) { heap.push((Reverse(r[j as usize]), AlwaysEqual(i * 32 + j))); @@ -239,13 +252,13 @@ impl Quantizer for ProductQuantizer { } } for i in main { - let r = Self::fscan_process(flut, locate_1(i).as_ref()); + let r = self.fscan_process(flut, locate_1(i).as_ref()); for j in 0..32 { heap.push((Reverse(r[j as usize]), AlwaysEqual(i * 32 + j))); } } if let Some(i) = right { - let r = Self::fscan_process(flut, locate_1(i).as_ref()); + let r = self.fscan_process(flut, locate_1(i).as_ref()); for j in 0..32 { if range.contains(&(i * 32 + j)) { heap.push((Reverse(r[j as usize]), AlwaysEqual(i * 32 + j))); @@ -276,8 +289,8 @@ impl Quantizer for ProductQuantizer { fn graph_rerank<'a, T, R, C>( &'a self, + lut: Self::Lut, locate: impl Fn(u32) -> C + 'a, - vector: Borrowed<'a, O>, rerank: R, ) -> impl RerankerPush + RerankerPop + 'a where @@ -285,9 +298,8 @@ impl Quantizer for ProductQuantizer { R: Fn(u32) -> (Distance, T) + 'a, C: AsRef<[u8]>, { - let lut = self.preprocess(vector); Graph2Reranker::new( - move |u| self.process(&lut, locate(u).as_ref(), vector), + move |u| O::process(self.dims, self.ratio, self.bits, &lut, locate(u).as_ref()), rerank, ) } diff --git a/crates/quantization/src/quantizer.rs b/crates/quantization/src/quantizer.rs index 0b27533e8..1c2800ab6 100644 --- a/crates/quantization/src/quantizer.rs +++ b/crates/quantization/src/quantizer.rs @@ -21,13 +21,15 @@ pub trait Quantizer: fn code_size(&self) -> u32; fn fcode_size(&self) -> u32; + fn project(&self, vector: Borrowed<'_, O>) -> Owned; + type Lut; fn preprocess(&self, vector: Borrowed<'_, O>) -> Self::Lut; fn process(&self, lut: &Self::Lut, code: &[u8], vector: Borrowed<'_, O>) -> Distance; type FLut; fn fscan_preprocess(&self, vector: Borrowed<'_, O>) -> Self::FLut; - fn fscan_process(flut: &Self::FLut, code: &[u8]) -> [Distance; 32]; + fn fscan_process(&self, flut: &Self::FLut, code: &[u8]) -> [Distance; 32]; type FlatRerankVec; @@ -60,8 +62,8 @@ pub trait Quantizer: fn graph_rerank<'a, T, R, C>( &'a self, + lut: Self::Lut, locate: impl Fn(u32) -> C + 'a, - vector: Borrowed<'a, O>, rerank: R, ) -> impl RerankerPush + RerankerPop + 'a where diff --git a/crates/quantization/src/rabitq.rs b/crates/quantization/src/rabitq.rs new file mode 100644 index 000000000..f471cadb2 --- /dev/null +++ b/crates/quantization/src/rabitq.rs @@ -0,0 +1,725 @@ +use crate::fast_scan::b4::fast_scan_b4; +use crate::fast_scan::b4::pack; +use crate::quantizer::Quantizer; +use crate::reranker::error::ErrorFlatReranker; +use crate::reranker::graph_2::Graph2Reranker; +use crate::utils::InfiniteByteChunks; +use base::always_equal::AlwaysEqual; +use base::distance::Distance; +use base::index::*; +use base::operator::*; +use base::scalar::impossible::Impossible; +use base::scalar::ScalarLike; +use base::search::*; +use base::vector::VectOwned; +use base::vector::VectorBorrowed; +use base::vector::VectorOwned; +use serde::{Deserialize, Serialize}; +use std::cmp::Reverse; +use std::marker::PhantomData; +use std::ops::Range; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct RabitqQuantizer { + dims: u32, + projection: Vec>, + _maker: PhantomData, +} + +impl Quantizer for RabitqQuantizer { + fn train( + vector_options: VectorOptions, + _: Option, + _: &(impl Vectors> + Sync), + _: impl Fn(Borrowed<'_, O>) -> Owned + Copy + Sync, + ) -> Self { + let dims = vector_options.dims; + let projection = { + use nalgebra::{DMatrix, QR}; + use rand::{Rng, SeedableRng}; + use rand_chacha::ChaCha12Rng; + use rand_distr::StandardNormal; + let mut rng = ChaCha12Rng::from_seed([7; 32]); + let matrix: Vec = (0..dims as usize * dims as usize) + .map(|_| rng.sample(StandardNormal)) + .collect(); + let matrix = DMatrix::from_vec(dims as usize, dims as usize, matrix); + let qr = QR::new(matrix); + let q = qr.q(); + let mut projection = Vec::new(); + for v in q.row_iter() { + let vector = v.iter().copied().collect::>(); + projection.push(O::Scalar::vector_from_f32(&vector)); + } + projection + }; + Self { + dims, + projection, + _maker: PhantomData, + } + } + + fn encode(&self, vector: Borrowed<'_, O>) -> Vec { + let dims = self.dims; + let (a, b, c, d, e) = O::code(vector); + let mut result = Vec::with_capacity(size_of::() * 4); + result.extend(a.to_ne_bytes()); + result.extend(b.to_ne_bytes()); + result.extend(c.to_ne_bytes()); + result.extend(d.to_ne_bytes()); + for x in InfiniteByteChunks::<_, 64>::new(e.into_iter()).take(dims.div_ceil(64) as usize) { + let mut r = 0_u64; + for i in 0..64 { + r |= (x[i] as u64) << i; + } + result.extend(r.to_ne_bytes().into_iter()); + } + result + } + + fn fscan_encode(&self, vectors: [Owned; 32]) -> Vec { + let dims = self.dims; + let coded = vectors.map(|vector| O::code(vector.as_borrowed())); + let codes = coded.clone().map(|(_, _, _, _, e)| { + InfiniteByteChunks::new(e.into_iter()) + .map(|[b0, b1, b2, b3]| b0 | b1 << 1 | b2 << 2 | b3 << 3) + .take(dims.div_ceil(4) as usize) + .collect() + }); + let mut result = Vec::with_capacity(size_of::() * 128); + for i in 0..32 { + result.extend(coded[i].0.to_ne_bytes()); + } + for i in 0..32 { + result.extend(coded[i].1.to_ne_bytes()); + } + for i in 0..32 { + result.extend(coded[i].2.to_ne_bytes()); + } + for i in 0..32 { + result.extend(coded[i].3.to_ne_bytes()); + } + result.extend(pack(dims.div_ceil(4), codes)); + result + } + + fn code_size(&self) -> u32 { + size_of::() as u32 * 4 + size_of::() as u32 * self.dims.div_ceil(64) + } + + fn fcode_size(&self) -> u32 { + size_of::() as u32 * 128 + self.dims.div_ceil(4) * 16 + } + + type Lut = O::Lut; + + fn preprocess(&self, vector: Borrowed<'_, O>) -> Self::Lut { + O::preprocess(vector) + } + + fn process(&self, lut: &Self::Lut, code: &[u8], _: Borrowed<'_, O>) -> Distance { + let c = parse_code(code); + O::process(lut, c) + } + + fn project(&self, vector: Borrowed<'_, O>) -> Owned { + O::project(&self.projection, vector) + } + + type FLut = O::FLut; + + fn fscan_preprocess(&self, vector: Borrowed<'_, O>) -> Self::FLut { + O::fscan_preprocess(vector) + } + + fn fscan_process(&self, flut: &Self::FLut, code: &[u8]) -> [Distance; 32] { + let c = parses_codes(code); + O::fscan_process(self.dims, flut, c) + } + + type FlatRerankVec = Vec<(Reverse, AlwaysEqual)>; + + fn flat_rerank_start() -> Self::FlatRerankVec { + Vec::new() + } + + fn flat_rerank_preprocess( + &self, + vector: Borrowed<'_, O>, + opts: &SearchOptions, + ) -> Result { + if opts.rq_fast_scan { + Ok(self.fscan_preprocess(vector)) + } else { + Err(self.preprocess(vector)) + } + } + + fn flat_rerank_continue( + &self, + locate_0: impl Fn(u32) -> C, + locate_1: impl Fn(u32) -> C, + frlut: &Result, + range: Range, + heap: &mut Self::FlatRerankVec, + ) where + C: AsRef<[u8]>, + { + match frlut { + Ok(flut) => { + fn divide(r: Range) -> (Option, Range, Option) { + if r.start > r.end { + return (None, r.start / 32..r.end / 32, None); + } + if r.start / 32 == r.end / 32 { + return (Some(r.start / 32), 0..0, None); + }; + let left = if r.start % 32 == 0 { + (None, r.start / 32) + } else { + (Some(r.start / 32), r.start / 32 + 1) + }; + let right = if r.end % 32 == 0 { + (r.end / 32, None) + } else { + (r.end / 32, Some(r.end / 32)) + }; + (left.0, left.1..right.0, right.1) + } + let (left, main, right) = divide(range.clone()); + if let Some(i) = left { + let c = locate_1(i); + let c = parses_codes(c.as_ref()); + let r = O::fscan_process_lowerbound(self.dims, flut, c, 1.9); + for j in 0..32 { + if range.contains(&(i * 32 + j)) { + heap.push((Reverse(r[j as usize]), AlwaysEqual(i * 32 + j))); + } + } + } + for i in main { + let c = locate_1(i); + let c = parses_codes(c.as_ref()); + let r = O::fscan_process_lowerbound(self.dims, flut, c, 1.9); + for j in 0..32 { + heap.push((Reverse(r[j as usize]), AlwaysEqual(i * 32 + j))); + } + } + if let Some(i) = right { + let c = locate_1(i); + let c = parses_codes(c.as_ref()); + let r = O::fscan_process_lowerbound(self.dims, flut, c, 1.9); + for j in 0..32 { + if range.contains(&(i * 32 + j)) { + heap.push((Reverse(r[j as usize]), AlwaysEqual(i * 32 + j))); + } + } + } + } + Err(lut) => { + for j in range { + let c = locate_0(j); + let c = parse_code(c.as_ref()); + let r = O::process_lowerbound(lut, c, 1.9); + heap.push((Reverse(r), AlwaysEqual(j))); + } + } + } + } + + fn flat_rerank_break<'a, T: 'a, R>( + &'a self, + heap: Self::FlatRerankVec, + rerank: R, + _: &SearchOptions, + ) -> impl RerankerPop + 'a + where + R: Fn(u32) -> (Distance, T) + 'a, + { + ErrorFlatReranker::new(heap, rerank) + } + + fn graph_rerank<'a, T, R, C>( + &'a self, + lut: Self::Lut, + locate: impl Fn(u32) -> C + 'a, + rerank: R, + ) -> impl RerankerPush + RerankerPop + 'a + where + T: 'a, + R: Fn(u32) -> (Distance, T) + 'a, + C: AsRef<[u8]>, + { + Graph2Reranker::new( + move |u| O::process(&lut, parse_code(locate(u).as_ref())), + rerank, + ) + } +} + +pub trait OperatorRabitqQuantization: Operator { + type Scalar: ScalarLike; + + fn code(vector: Borrowed<'_, Self>) -> (f32, f32, f32, f32, Vec); + + fn project(projection: &[Vec], vector: Borrowed<'_, Self>) -> Owned; + + type Lut; + fn preprocess(vector: Borrowed<'_, Self>) -> Self::Lut; + fn process(lut: &Self::Lut, code: (f32, f32, f32, f32, &[u64])) -> Distance; + fn process_lowerbound( + lut: &Self::Lut, + code: (f32, f32, f32, f32, &[u64]), + epsilon: f32, + ) -> Distance; + + type FLut; + fn fscan_preprocess(vector: Borrowed<'_, Self>) -> Self::FLut; + fn fscan_process( + dims: u32, + lut: &Self::FLut, + code: (&[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32], &[u8]), + ) -> [Distance; 32]; + fn fscan_process_lowerbound( + dims: u32, + lut: &Self::FLut, + code: (&[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32], &[u8]), + epsilon: f32, + ) -> [Distance; 32]; +} + +impl OperatorRabitqQuantization for VectL2 { + type Scalar = S; + + fn code(vector: Borrowed<'_, Self>) -> (f32, f32, f32, f32, Vec) { + let dims = vector.dims(); + let vector = vector.slice(); + let sum_of_abs_x = S::reduce_sum_of_abs_x(vector); + let sum_of_x2 = S::reduce_sum_of_x2(vector); + let dis_u = sum_of_x2.sqrt(); + let x0 = sum_of_abs_x / (sum_of_x2 * (dims as f32)).sqrt(); + let x_x0 = dis_u / x0; + let fac_norm = (dims as f32).sqrt(); + let max_x1 = 1.0f32 / (dims as f32 - 1.0).sqrt(); + let factor_err = 2.0f32 * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt(); + let factor_ip = -2.0f32 / fac_norm * x_x0; + let cnt_pos = vector + .iter() + .map(|x| x.scalar_is_sign_positive() as i32) + .sum::(); + let cnt_neg = vector + .iter() + .map(|x| x.scalar_is_sign_negative() as i32) + .sum::(); + let factor_ppc = factor_ip * (cnt_pos - cnt_neg) as f32; + let mut code = Vec::new(); + for i in 0..dims { + code.push(vector[i as usize].scalar_is_sign_positive() as u8); + } + (sum_of_x2, factor_ppc, factor_ip, factor_err, code) + } + + fn project(projection: &[Vec], vector: Borrowed<'_, Self>) -> Owned { + let slice = (0..projection.len()) + .map(|i| S::from_f32(S::reduce_sum_of_xy(&projection[i], vector.slice()))) + .collect(); + VectOwned::new(slice) + } + + type Lut = (f32, f32, f32, f32, (Vec, Vec, Vec, Vec)); + + fn preprocess(vector: Borrowed<'_, Self>) -> Self::Lut { + use crate::quantize; + let vector = vector.slice(); + let dis_v_2 = S::reduce_sum_of_x2(vector); + let (k, b, qvector) = quantize::quantize::<15>(S::vector_to_f32_borrowed(vector).as_ref()); + let qvector_sum = if vector.len() <= 4369 { + quantize::reduce_sum_of_x_as_u16(&qvector) as f32 + } else { + quantize::reduce_sum_of_x_as_u32(&qvector) as f32 + }; + let lut = binarize(&qvector); + (dis_v_2, b, k, qvector_sum, lut) + } + + fn process( + lut: &Self::Lut, + (dis_u_2, factor_ppc, factor_ip, _, t): (f32, f32, f32, f32, &[u64]), + ) -> Distance { + let &(dis_v_2, b, k, qvector_sum, ref s) = lut; + let value = asymmetric_binary_dot_product(t, s) as u16; + let rough = dis_u_2 + + dis_v_2 + + b * factor_ppc + + ((2.0 * value as f32) - qvector_sum) * factor_ip * k; + Distance::from_f32(rough) + } + + fn process_lowerbound( + lut: &Self::Lut, + (dis_u_2, factor_ppc, factor_ip, factor_err, t): (f32, f32, f32, f32, &[u64]), + epsilon: f32, + ) -> Distance { + let &(dis_v_2, b, k, qvector_sum, ref s) = lut; + let value = asymmetric_binary_dot_product(t, s) as u16; + let rough = dis_u_2 + + dis_v_2 + + b * factor_ppc + + ((2.0 * value as f32) - qvector_sum) * factor_ip * k; + let err = factor_err * dis_v_2.sqrt(); + Distance::from_f32(rough - epsilon * err) + } + + type FLut = (f32, f32, f32, f32, Vec); + + fn fscan_preprocess(vector: Borrowed<'_, Self>) -> Self::FLut { + use crate::quantize; + let vector = vector.slice(); + let dis_v_2 = S::reduce_sum_of_x2(vector); + let (k, b, qvector) = quantize::quantize::<15>(S::vector_to_f32_borrowed(vector).as_ref()); + let qvector_sum = if vector.len() <= 4369 { + quantize::reduce_sum_of_x_as_u16(&qvector) as f32 + } else { + quantize::reduce_sum_of_x_as_u32(&qvector) as f32 + }; + let lut = gen(qvector); + (dis_v_2, b, k, qvector_sum, lut) + } + + fn fscan_process( + dims: u32, + lut: &Self::FLut, + (dis_u_2, factor_ppc, factor_ip, _, t): ( + &[f32; 32], + &[f32; 32], + &[f32; 32], + &[f32; 32], + &[u8], + ), + ) -> [Distance; 32] { + let &(dis_v_2, b, k, qvector_sum, ref s) = lut; + let r = fast_scan_b4(dims.div_ceil(4), t, s); + std::array::from_fn(|i| { + let rough = dis_u_2[i] + + dis_v_2 + + b * factor_ppc[i] + + ((2.0 * r[i] as f32) - qvector_sum) * factor_ip[i] * k; + Distance::from_f32(rough) + }) + } + + fn fscan_process_lowerbound( + dims: u32, + lut: &Self::FLut, + (dis_u_2, factor_ppc, factor_ip, factor_err, t): ( + &[f32; 32], + &[f32; 32], + &[f32; 32], + &[f32; 32], + &[u8], + ), + epsilon: f32, + ) -> [Distance; 32] { + let &(dis_v_2, b, k, qvector_sum, ref s) = lut; + let r = fast_scan_b4(dims.div_ceil(4), t, s); + std::array::from_fn(|i| { + let rough = dis_u_2[i] + + dis_v_2 + + b * factor_ppc[i] + + ((2.0 * r[i] as f32) - qvector_sum) * factor_ip[i] * k; + let err = factor_err[i] * dis_v_2.sqrt(); + Distance::from_f32(rough - epsilon * err) + }) + } +} + +impl OperatorRabitqQuantization for VectDot { + type Scalar = S; + + fn code(vector: Borrowed<'_, Self>) -> (f32, f32, f32, f32, Vec) { + let dims = vector.dims(); + let vector = vector.slice(); + let sum_of_abs_x = S::reduce_sum_of_abs_x(vector); + let sum_of_x2 = S::reduce_sum_of_x2(vector); + let dis_u = sum_of_x2.sqrt(); + let x0 = sum_of_abs_x / (sum_of_x2 * (dims as f32)).sqrt(); + let x_x0 = dis_u / x0; + let fac_norm = (dims as f32).sqrt(); + let max_x1 = 1.0f32 / (dims as f32 - 1.0).sqrt(); + let factor_err = 2.0f32 * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt(); + let factor_ip = -2.0f32 / fac_norm * x_x0; + let cnt_pos = vector + .iter() + .map(|x| x.scalar_is_sign_positive() as i32) + .sum::(); + let cnt_neg = vector + .iter() + .map(|x| x.scalar_is_sign_negative() as i32) + .sum::(); + let factor_ppc = factor_ip * (cnt_pos - cnt_neg) as f32; + let mut code = Vec::new(); + for i in 0..dims { + code.push(vector[i as usize].scalar_is_sign_positive() as u8); + } + (sum_of_x2, factor_ppc, factor_ip, factor_err, code) + } + + fn project(projection: &[Vec], vector: Borrowed<'_, Self>) -> Owned { + let slice = (0..projection.len()) + .map(|i| S::from_f32(S::reduce_sum_of_xy(&projection[i], vector.slice()))) + .collect(); + VectOwned::new(slice) + } + + type Lut = (f32, f32, f32, f32, (Vec, Vec, Vec, Vec)); + + fn preprocess(vector: Borrowed<'_, Self>) -> Self::Lut { + use crate::quantize; + let vector = vector.slice(); + let dis_v_2 = S::reduce_sum_of_x2(vector); + let (k, b, qvector) = quantize::quantize::<15>(S::vector_to_f32_borrowed(vector).as_ref()); + let qvector_sum = if vector.len() <= 4369 { + quantize::reduce_sum_of_x_as_u16(&qvector) as f32 + } else { + quantize::reduce_sum_of_x_as_u32(&qvector) as f32 + }; + let lut = binarize(&qvector); + (dis_v_2, b, k, qvector_sum, lut) + } + + fn process( + lut: &Self::Lut, + (_, factor_ppc, factor_ip, _, t): (f32, f32, f32, f32, &[u64]), + ) -> Distance { + let &(_, b, k, qvector_sum, ref s) = lut; + let value = asymmetric_binary_dot_product(t, s) as u16; + let rough = + 0.5 * b * factor_ppc + 0.5 * ((2.0 * value as f32) - qvector_sum) * factor_ip * k; + Distance::from_f32(rough) + } + + fn process_lowerbound( + lut: &Self::Lut, + (_, factor_ppc, factor_ip, factor_err, t): (f32, f32, f32, f32, &[u64]), + epsilon: f32, + ) -> Distance { + let &(dis_v_2, b, k, qvector_sum, ref s) = lut; + let value = asymmetric_binary_dot_product(t, s) as u16; + let rough = + 0.5 * b * factor_ppc + 0.5 * ((2.0 * value as f32) - qvector_sum) * factor_ip * k; + let err = 0.5 * factor_err * dis_v_2.sqrt(); + Distance::from_f32(rough - epsilon * err) + } + + type FLut = (f32, f32, f32, f32, Vec); + + fn fscan_preprocess(vector: Borrowed<'_, Self>) -> Self::FLut { + use crate::quantize; + let vector = vector.slice(); + let dis_v_2 = S::reduce_sum_of_x2(vector); + let (k, b, qvector) = quantize::quantize::<15>(S::vector_to_f32_borrowed(vector).as_ref()); + let qvector_sum = if vector.len() <= 4369 { + quantize::reduce_sum_of_x_as_u16(&qvector) as f32 + } else { + quantize::reduce_sum_of_x_as_u32(&qvector) as f32 + }; + let lut = gen(qvector); + (dis_v_2, b, k, qvector_sum, lut) + } + + fn fscan_process( + dims: u32, + lut: &Self::FLut, + (_, factor_ppc, factor_ip, _, t): (&[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32], &[u8]), + ) -> [Distance; 32] { + let &(_, b, k, qvector_sum, ref s) = lut; + let r = fast_scan_b4(dims.div_ceil(4), t, s); + std::array::from_fn(|i| { + let rough = 0.5 * b * factor_ppc[i] + + 0.5 * ((2.0 * r[i] as f32) - qvector_sum) * factor_ip[i] * k; + Distance::from_f32(rough) + }) + } + + fn fscan_process_lowerbound( + dims: u32, + lut: &Self::FLut, + (_, factor_ppc, factor_ip, factor_err, t): ( + &[f32; 32], + &[f32; 32], + &[f32; 32], + &[f32; 32], + &[u8], + ), + epsilon: f32, + ) -> [Distance; 32] { + let &(dis_v_2, b, k, qvector_sum, ref s) = lut; + let r = fast_scan_b4(dims.div_ceil(4), t, s); + std::array::from_fn(|i| { + let rough = 0.5 * b * factor_ppc[i] + + 0.5 * ((2.0 * r[i] as f32) - qvector_sum) * factor_ip[i] * k; + let err = 0.5 * factor_err[i] * dis_v_2.sqrt(); + Distance::from_f32(rough - epsilon * err) + }) + } +} + +macro_rules! unimpl_operator_rabitq_quantization { + ($t:ty) => { + impl OperatorRabitqQuantization for $t { + type Scalar = Impossible; + + fn code(_: Borrowed<'_, Self>) -> (f32, f32, f32, f32, Vec) { + unimplemented!() + } + + fn project(_: &[Vec], _: Borrowed<'_, Self>) -> Owned { + unimplemented!() + } + + type Lut = std::convert::Infallible; + fn preprocess(_: Borrowed<'_, Self>) -> Self::Lut { + unimplemented!() + } + fn process(_: &Self::Lut, _: (f32, f32, f32, f32, &[u64])) -> Distance { + unimplemented!() + } + fn process_lowerbound( + _: &Self::Lut, + _: (f32, f32, f32, f32, &[u64]), + _: f32, + ) -> Distance { + unimplemented!() + } + + type FLut = std::convert::Infallible; + fn fscan_preprocess(_: Borrowed<'_, Self>) -> Self::FLut { + unimplemented!() + } + fn fscan_process( + _: u32, + _: &Self::Lut, + _: (&[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32], &[u8]), + ) -> [Distance; 32] { + unimplemented!() + } + fn fscan_process_lowerbound( + _: u32, + _: &Self::Lut, + _: (&[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32], &[u8]), + _: f32, + ) -> [Distance; 32] { + unimplemented!() + } + } + }; +} + +unimpl_operator_rabitq_quantization!(BVectorDot); +unimpl_operator_rabitq_quantization!(BVectorHamming); +unimpl_operator_rabitq_quantization!(BVectorJaccard); + +unimpl_operator_rabitq_quantization!(SVectDot); +unimpl_operator_rabitq_quantization!(SVectL2); + +fn parse_code(code: &[u8]) -> (f32, f32, f32, f32, &[u64]) { + assert!(code.len() > size_of::() * 4, "length is incorrect"); + assert!(code.len() % size_of::() == 0, "length is incorrect"); + assert!(code.as_ptr() as usize % 8 == 0, "pointer is not aligned"); + unsafe { + let a = code.as_ptr().add(0).cast::().read(); + let b = code.as_ptr().add(4).cast::().read(); + let c = code.as_ptr().add(8).cast::().read(); + let d = code.as_ptr().add(12).cast::().read(); + let e = std::slice::from_raw_parts(code[16..].as_ptr().cast(), code[16..].len() / 8); + (a, b, c, d, e) + } +} + +fn parses_codes(code: &[u8]) -> (&[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32], &[u8]) { + assert!(code.len() > size_of::() * 128, "length is incorrect"); + assert!(code.as_ptr() as usize % 4 == 0, "pointer is not aligned"); + unsafe { + let a = &*code.as_ptr().add(0).cast::<[f32; 32]>(); + let b = &*code.as_ptr().add(128).cast::<[f32; 32]>(); + let c = &*code.as_ptr().add(256).cast::<[f32; 32]>(); + let d = &*code.as_ptr().add(384).cast::<[f32; 32]>(); + let e = &code[512..]; + (a, b, c, d, e) + } +} + +fn gen(mut qvector: Vec) -> Vec { + let dims = qvector.len() as u32; + let t = dims.div_ceil(4); + qvector.resize(qvector.len().next_multiple_of(4), 0); + let mut lut = vec![0u8; t as usize * 16]; + for i in 0..t as usize { + unsafe { + // this hint is used to skip bound checks + std::hint::assert_unchecked(4 * i + 3 < qvector.len()); + std::hint::assert_unchecked(16 * i + 15 < lut.len()); + } + let t0 = qvector[4 * i + 0]; + let t1 = qvector[4 * i + 1]; + let t2 = qvector[4 * i + 2]; + let t3 = qvector[4 * i + 3]; + lut[16 * i + 0b0000] = 0; + lut[16 * i + 0b0001] = t0; + lut[16 * i + 0b0010] = t1; + lut[16 * i + 0b0011] = t1 + t0; + lut[16 * i + 0b0100] = t2; + lut[16 * i + 0b0101] = t2 + t0; + lut[16 * i + 0b0110] = t2 + t1; + lut[16 * i + 0b0111] = t2 + t1 + t0; + lut[16 * i + 0b1000] = t3; + lut[16 * i + 0b1001] = t3 + t0; + lut[16 * i + 0b1010] = t3 + t1; + lut[16 * i + 0b1011] = t3 + t1 + t0; + lut[16 * i + 0b1100] = t3 + t2; + lut[16 * i + 0b1101] = t3 + t2 + t0; + lut[16 * i + 0b1110] = t3 + t2 + t1; + lut[16 * i + 0b1111] = t3 + t2 + t1 + t0; + } + lut +} + +fn binarize(vector: &[u8]) -> (Vec, Vec, Vec, Vec) { + let n = vector.len(); + let mut t0 = vec![0u64; n.div_ceil(64)]; + let mut t1 = vec![0u64; n.div_ceil(64)]; + let mut t2 = vec![0u64; n.div_ceil(64)]; + let mut t3 = vec![0u64; n.div_ceil(64)]; + for i in 0..n { + t0[i / 64] |= (((vector[i] >> 0) & 1) as u64) << (i % 64); + t1[i / 64] |= (((vector[i] >> 1) & 1) as u64) << (i % 64); + t2[i / 64] |= (((vector[i] >> 2) & 1) as u64) << (i % 64); + t3[i / 64] |= (((vector[i] >> 3) & 1) as u64) << (i % 64); + } + (t0, t1, t2, t3) +} + +#[detect::multiversion(v2, fallback)] +fn asymmetric_binary_dot_product(x: &[u64], y: &(Vec, Vec, Vec, Vec)) -> u32 { + assert_eq!(x.len(), y.0.len()); + assert_eq!(x.len(), y.1.len()); + assert_eq!(x.len(), y.2.len()); + assert_eq!(x.len(), y.3.len()); + let n = x.len(); + let (mut t0, mut t1, mut t2, mut t3) = (0, 0, 0, 0); + for i in 0..n { + t0 += (x[i] & y.0[i]).count_ones(); + } + for i in 0..n { + t1 += (x[i] & y.1[i]).count_ones(); + } + for i in 0..n { + t2 += (x[i] & y.2[i]).count_ones(); + } + for i in 0..n { + t3 += (x[i] & y.3[i]).count_ones(); + } + (t0 << 0) + (t1 << 1) + (t2 << 2) + (t3 << 3) +} diff --git a/crates/rabitq/src/quant/error.rs b/crates/quantization/src/reranker/error.rs similarity index 100% rename from crates/rabitq/src/quant/error.rs rename to crates/quantization/src/reranker/error.rs diff --git a/crates/quantization/src/reranker/mod.rs b/crates/quantization/src/reranker/mod.rs index 17dabb167..2576d564c 100644 --- a/crates/quantization/src/reranker/mod.rs +++ b/crates/quantization/src/reranker/mod.rs @@ -1,3 +1,4 @@ +pub mod error; pub mod flat; pub mod graph; pub mod graph_2; diff --git a/crates/quantization/src/scalar.rs b/crates/quantization/src/scalar.rs index f6745121a..eea2476aa 100644 --- a/crates/quantization/src/scalar.rs +++ b/crates/quantization/src/scalar.rs @@ -143,6 +143,10 @@ impl Quantizer for ScalarQuantizer { } } + fn project(&self, vector: Borrowed<'_, O>) -> Owned { + vector.own() + } + type Lut = Vec; fn preprocess(&self, vector: Borrowed<'_, O>) -> Self::Lut { @@ -164,7 +168,7 @@ impl Quantizer for ScalarQuantizer { O::fscan_preprocess(self.dims, self.bits, &self.max, &self.min, vector) } - fn fscan_process(flut: &Self::FLut, code: &[u8]) -> [Distance; 32] { + fn fscan_process(&self, flut: &Self::FLut, code: &[u8]) -> [Distance; 32] { O::fscan_process(flut, code) } @@ -211,19 +215,27 @@ impl Quantizer for ScalarQuantizer { match frlut { Ok(flut) => { fn divide(r: Range) -> (Option, Range, Option) { - if r.start > r.end || r.start % 32 == 0 && r.end % 32 == 0 { - (None, r.start / 32..r.end / 32, None) - } else if r.start / 32 == r.end / 32 { - (Some(r.start / 32), 0..0, None) - } else { - let left = (r.start % 32 != 0).then_some(r.start / 32); - let right = (r.end % 32 != 0).then_some(r.end / 32); - (left, r.start / 32 + 1..r.end / 32, right) + if r.start > r.end { + return (None, r.start / 32..r.end / 32, None); } + if r.start / 32 == r.end / 32 { + return (Some(r.start / 32), 0..0, None); + }; + let left = if r.start % 32 == 0 { + (None, r.start / 32) + } else { + (Some(r.start / 32), r.start / 32 + 1) + }; + let right = if r.end % 32 == 0 { + (r.end / 32, None) + } else { + (r.end / 32, Some(r.end / 32)) + }; + (left.0, left.1..right.0, right.1) } let (left, main, right) = divide(range.clone()); if let Some(i) = left { - let r = Self::fscan_process(flut, locate_1(i).as_ref()); + let r = self.fscan_process(flut, locate_1(i).as_ref()); for j in 0..32 { if range.contains(&(i * 32 + j)) { heap.push((Reverse(r[j as usize]), AlwaysEqual(i * 32 + j))); @@ -231,13 +243,13 @@ impl Quantizer for ScalarQuantizer { } } for i in main { - let r = Self::fscan_process(flut, locate_1(i).as_ref()); + let r = self.fscan_process(flut, locate_1(i).as_ref()); for j in 0..32 { heap.push((Reverse(r[j as usize]), AlwaysEqual(i * 32 + j))); } } if let Some(i) = right { - let r = Self::fscan_process(flut, locate_1(i).as_ref()); + let r = self.fscan_process(flut, locate_1(i).as_ref()); for j in 0..32 { if range.contains(&(i * 32 + j)) { heap.push((Reverse(r[j as usize]), AlwaysEqual(i * 32 + j))); @@ -256,8 +268,8 @@ impl Quantizer for ScalarQuantizer { fn graph_rerank<'a, T, R, C>( &'a self, + lut: Self::Lut, locate: impl Fn(u32) -> C + 'a, - vector: Borrowed<'a, O>, rerank: R, ) -> impl RerankerPush + RerankerPop + 'a where @@ -265,9 +277,8 @@ impl Quantizer for ScalarQuantizer { R: Fn(u32) -> (Distance, T) + 'a, C: AsRef<[u8]>, { - let lut = self.preprocess(vector); Graph2Reranker::new( - move |u| self.process(&lut, locate(u).as_ref(), vector), + move |u| O::process(self.dims, self.bits, &lut, locate(u).as_ref()), rerank, ) } diff --git a/crates/quantization/src/trivial.rs b/crates/quantization/src/trivial.rs index 0094cb816..076fb93c9 100644 --- a/crates/quantization/src/trivial.rs +++ b/crates/quantization/src/trivial.rs @@ -48,6 +48,10 @@ impl Quantizer for TrivialQuantizer { 0 } + fn project(&self, vector: Borrowed<'_, O>) -> Owned { + vector.own() + } + type Lut = Owned; fn preprocess(&self, vector: Borrowed<'_, O>) -> Self::Lut { @@ -64,7 +68,7 @@ impl Quantizer for TrivialQuantizer { unimplemented!() } - fn fscan_process(_: &Self::FLut, _: &[u8]) -> [Distance; 32] { + fn fscan_process(&self, _: &Self::FLut, _: &[u8]) -> [Distance; 32] { unimplemented!() } @@ -114,8 +118,8 @@ impl Quantizer for TrivialQuantizer { fn graph_rerank<'a, T, R, C>( &'a self, + _: Self::Lut, _: impl Fn(u32) -> C + 'a, - _: Borrowed<'a, O>, rerank: R, ) -> impl RerankerPush + RerankerPop + 'a where diff --git a/crates/rabitq/Cargo.toml b/crates/rabitq/Cargo.toml deleted file mode 100644 index be20863e4..000000000 --- a/crates/rabitq/Cargo.toml +++ /dev/null @@ -1,24 +0,0 @@ -[package] -name = "rabitq" -version.workspace = true -edition.workspace = true - -[dependencies] -half.workspace = true -log.workspace = true -rand.workspace = true -rand_distr.workspace = true -serde.workspace = true -serde_json.workspace = true - -base = { path = "../base" } -common = { path = "../common" } -detect = { path = "../detect" } -k_means = { version = "0.0.0", path = "../k_means" } -nalgebra = { version = "0.33.0", features = ["debug"] } -quantization = { path = "../quantization" } -stoppable_rayon = { path = "../stoppable_rayon" } -storage = { version = "0.0.0", path = "../storage" } - -[lints] -workspace = true diff --git a/crates/rabitq/src/lib.rs b/crates/rabitq/src/lib.rs deleted file mode 100644 index a90433b18..000000000 --- a/crates/rabitq/src/lib.rs +++ /dev/null @@ -1,255 +0,0 @@ -#![allow(clippy::needless_range_loop)] -#![allow(clippy::type_complexity)] -#![allow(clippy::identity_op)] -#![allow(clippy::too_many_arguments)] -#![allow(clippy::len_without_is_empty)] - -pub mod operator; -pub mod quant; - -use crate::operator::OperatorRabitq as Op; -use crate::quant::quantization::Quantization; -use base::always_equal::AlwaysEqual; -use base::index::{IndexOptions, RabitqIndexingOptions, SearchOptions}; -use base::operator::{Borrowed, Owned}; -use base::search::{Collection, Element, Payload, RerankerPop, Source, Vectors}; -use common::json::Json; -use common::mmap_array::MmapArray; -use common::remap::RemappedCollection; -use common::vec2::Vec2; -use k_means::{k_means, k_means_lookup, k_means_lookup_many}; -use rayon::iter::{IntoParallelIterator, ParallelIterator}; -use std::fs::create_dir; -use std::path::Path; -use stoppable_rayon as rayon; -use storage::Storage; - -pub struct Rabitq { - storage: O::Storage, - quantization: Quantization, - payloads: MmapArray, - offsets: Json>, - projected_centroids: Json>, - projection: Json>>, - is_residual: Json, -} - -impl Rabitq { - pub fn create( - path: impl AsRef, - options: IndexOptions, - source: &(impl Vectors> + Collection + Source + Sync), - ) -> Self { - let remapped = RemappedCollection::from_source(source); - from_nothing(path, options, &remapped) - } - - pub fn open(path: impl AsRef) -> Self { - open(path) - } - - pub fn dims(&self) -> u32 { - self.storage.dims() - } - - pub fn len(&self) -> u32 { - self.storage.len() - } - - pub fn vector(&self, i: u32) -> Borrowed<'_, O> { - self.storage.vector(i) - } - - pub fn payload(&self, i: u32) -> Payload { - self.payloads[i as usize] - } - - pub fn vbase<'a>( - &'a self, - vector: Borrowed<'a, O>, - opts: &'a SearchOptions, - ) -> Box + 'a> { - let projected_query = O::proj(&self.projection, O::cast(vector)); - let lists = select( - k_means_lookup_many(&projected_query, &self.projected_centroids), - opts.rabitq_nprobe as usize, - ); - let mut heap = Vec::new(); - for &(dis_v2, i) in lists.iter() { - let trans_vector = if *self.is_residual { - &O::residual(&projected_query, &self.projected_centroids[(i,)]) - } else { - &projected_query - }; - let preprocessed = if opts.rabitq_fast_scan { - self.quantization.fscan_preprocess(trans_vector, dis_v2) - } else { - self.quantization.preprocess(trans_vector, dis_v2) - }; - let start = self.offsets[i]; - let end = self.offsets[i + 1]; - self.quantization - .push_batch(&preprocessed, start..end, &mut heap, opts.rabitq_epsilon); - } - let mut reranker = self.quantization.rerank(heap, move |u| { - (O::distance(vector, self.storage.vector(u)), ()) - }); - Box::new(std::iter::from_fn(move || { - reranker.pop().map(|(dis_u, u, ())| Element { - distance: dis_u, - payload: AlwaysEqual(self.payload(u)), - }) - })) - } -} - -fn from_nothing( - path: impl AsRef, - options: IndexOptions, - collection: &(impl Vectors> + Collection + Sync), -) -> Rabitq { - create_dir(path.as_ref()).unwrap(); - let RabitqIndexingOptions { - nlist, - spherical_centroids, - residual_quantization, - } = options.indexing.clone().unwrap_rabitq(); - let projection = { - use nalgebra::{DMatrix, QR}; - use rand::Rng; - use rand_distr::StandardNormal; - let mut rng = rand::thread_rng(); - let dims = options.vector.dims; - let matrix: Vec = (0..dims as usize * dims as usize) - .map(|_| rng.sample(StandardNormal)) - .collect(); - let matrix = DMatrix::from_vec(dims as usize, dims as usize, matrix); - let qr = QR::new(matrix); - let q = qr.q(); - let mut projection = vec![Vec::with_capacity(dims as usize); dims as usize]; - for (i, v) in q.row_iter().enumerate() { - for &val in v.iter() { - projection[i].push(val); - } - } - projection - }; - let is_residual = residual_quantization && O::SUPPORT_RESIDUAL; - rayon::check(); - let samples = O::sample(collection, nlist); - rayon::check(); - let centroids: Vec2 = k_means(nlist as usize, samples, true, spherical_centroids, false); - rayon::check(); - let ls = (0..collection.len()) - .into_par_iter() - .fold( - || vec![Vec::new(); nlist as usize], - |mut state, i| { - state[k_means_lookup(O::cast(collection.vector(i)), ¢roids)].push(i); - state - }, - ) - .reduce( - || vec![Vec::new(); nlist as usize], - |lhs, rhs| { - std::iter::zip(lhs, rhs) - .map(|(lhs, rhs)| { - let mut x = lhs; - x.extend(rhs); - x - }) - .collect() - }, - ); - let mut offsets = vec![0u32; nlist as usize + 1]; - for i in 0..nlist { - offsets[i as usize + 1] = offsets[i as usize] + ls[i as usize].len() as u32; - } - let remap = ls - .into_iter() - .flat_map(|x| x.into_iter()) - .collect::>(); - let collection = RemappedCollection::from_collection(collection, remap); - rayon::check(); - let storage = O::Storage::create(path.as_ref().join("storage"), &collection); - - let quantization = if is_residual { - Quantization::create( - path.as_ref().join("quantization"), - options.vector, - collection.len(), - |vector| { - let vector = O::cast(collection.vector(vector)); - let target = k_means_lookup(vector, ¢roids); - O::proj(&projection, &O::residual(vector, ¢roids[(target,)])) - }, - ) - } else { - Quantization::create( - path.as_ref().join("quantization"), - options.vector, - collection.len(), - |vector| { - let vector = O::cast(collection.vector(vector)); - O::proj(&projection, vector) - }, - ) - }; - - let projected_centroids = Vec2::from_vec( - (centroids.shape_0(), centroids.shape_1()), - (0..centroids.shape_0()) - .flat_map(|x| O::proj(&projection, ¢roids[(x,)])) - .collect(), - ); - let payloads = MmapArray::create( - path.as_ref().join("payloads"), - (0..collection.len()).map(|i| collection.payload(i)), - ); - let offsets = Json::create(path.as_ref().join("offsets"), offsets); - let projected_centroids = Json::create( - path.as_ref().join("projected_centroids"), - projected_centroids, - ); - let projection = Json::create(path.as_ref().join("projection"), projection); - let is_residual = Json::create(path.as_ref().join("is_residual"), is_residual); - Rabitq { - storage, - payloads, - offsets, - projected_centroids, - quantization, - projection, - is_residual, - } -} - -fn open(path: impl AsRef) -> Rabitq { - let storage = O::Storage::open(path.as_ref().join("storage")); - let quantization = Quantization::open(path.as_ref().join("quantization")); - let payloads = MmapArray::open(path.as_ref().join("payloads")); - let offsets = Json::open(path.as_ref().join("offsets")); - let projected_centroids = Json::open(path.as_ref().join("projected_centroids")); - let projection = Json::open(path.as_ref().join("projection")); - let is_residual = Json::open(path.as_ref().join("is_residual")); - Rabitq { - storage, - quantization, - payloads, - offsets, - projected_centroids, - projection, - is_residual, - } -} - -fn select(mut lists: Vec<(f32, T)>, n: usize) -> Vec<(f32, T)> { - if lists.is_empty() || n == 0 { - return Vec::new(); - } - let n = n.min(lists.len()); - lists.select_nth_unstable_by(n - 1, |(x, _), (y, _)| f32::total_cmp(x, y)); - lists.truncate(n); - lists.sort_by(|(x, _), (y, _)| f32::total_cmp(x, y)); - lists -} diff --git a/crates/rabitq/src/operator.rs b/crates/rabitq/src/operator.rs deleted file mode 100644 index fb1db69e0..000000000 --- a/crates/rabitq/src/operator.rs +++ /dev/null @@ -1,694 +0,0 @@ -use std::ops::Index; - -use base::distance::Distance; -use base::operator::Borrowed; -use base::operator::*; -use base::scalar::ScalarLike; -use base::search::Vectors; -use common::vec2::Vec2; -use half::f16; -use storage::OperatorStorage; - -pub trait OperatorRabitq: OperatorStorage { - fn sample(vectors: &impl Vectors, nlist: u32) -> Vec2; - fn cast(vector: Borrowed<'_, Self>) -> &[f32]; - - const SUPPORT_RESIDUAL: bool; - fn residual(lhs: &[f32], rhs: &[f32]) -> Vec; - fn proj(projection: &[Vec], vector: &[f32]) -> Vec; - - type VectorParams: IntoIterator; - type QvectorParams; - type QvectorLookup; - - fn train_encode(dims: u32, vector: Vec) -> Self::VectorParams; - fn train_decode + ?Sized>(u: u32, meta: &T) - -> Self::VectorParams; - fn preprocess(trans_vector: &[f32], dis_v_2: f32) - -> (Self::QvectorParams, Self::QvectorLookup); - fn process( - vector_params: &Self::VectorParams, - qvector_code: &[u8], - qvector_params: &Self::QvectorParams, - qvector_lookup: &Self::QvectorLookup, - ) -> Distance; - fn process_lowerbound( - vector_params: &Self::VectorParams, - qvector_code: &[u8], - qvector_params: &Self::QvectorParams, - qvector_lookup: &Self::QvectorLookup, - epsilon: f32, - ) -> Distance; - fn fscan_preprocess(trans_vector: &[f32], dis_v_2: f32) -> (Self::QvectorParams, Vec); - fn fscan_process_lowerbound( - vector_params: &Self::VectorParams, - qvector_params: &Self::QvectorParams, - binary_prod: u16, - epsilon: f32, - ) -> Distance; -} - -impl OperatorRabitq for VectL2 { - fn sample(vectors: &impl Vectors, nlist: u32) -> Vec2 { - common::sample::sample( - vectors.len(), - nlist.saturating_mul(256).min(1 << 20), - vectors.dims(), - |i| vectors.vector(i).slice(), - ) - } - fn cast(vector: Borrowed<'_, Self>) -> &[f32] { - vector.slice() - } - const SUPPORT_RESIDUAL: bool = true; - fn residual(lhs: &[f32], rhs: &[f32]) -> Vec { - f32::vector_sub(lhs, rhs) - } - fn proj(projection: &[Vec], vector: &[f32]) -> Vec { - let dims = vector.len(); - assert_eq!(projection.len(), dims); - (0..dims) - .map(|i| f32::reduce_sum_of_xy(&projection[i], vector)) - .collect() - } - - // [dis_u_2, factor_ppc, factor_ip, factor_err] - type VectorParams = [f32; 4]; - // (dis_v_2, b, k, qvector_sum) - type QvectorParams = (f32, f32, f32, f32); - type QvectorLookup = ((Vec, Vec, Vec, Vec), Vec); - - fn train_encode(dims: u32, vector: Vec) -> Self::VectorParams { - let sum_of_abs_x = f32::reduce_sum_of_abs_x(&vector); - let dis_u_2 = f32::reduce_sum_of_x2(&vector); - let dis_u = dis_u_2.sqrt(); - let x0 = sum_of_abs_x / (dis_u_2 * (dims as f32)).sqrt(); - let x_x0 = dis_u / x0; - let fac_norm = (dims as f32).sqrt(); - let max_x1 = 1.0f32 / (dims as f32 - 1.0).sqrt(); - let factor_err = 2.0f32 * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt(); - let factor_ip = -2.0f32 / fac_norm * x_x0; - let cnt_pos = vector - .iter() - .map(|x| x.is_sign_positive() as i32) - .sum::(); - let cnt_neg = vector - .iter() - .map(|x| x.is_sign_negative() as i32) - .sum::(); - let factor_ppc = factor_ip * (cnt_pos - cnt_neg) as f32; - [dis_u_2, factor_ppc, factor_ip, factor_err] - } - - fn train_decode + ?Sized>( - u: u32, - meta: &T, - ) -> Self::VectorParams { - let dis_u_2 = meta[4 * u as usize + 0]; - let factor_ppc = meta[4 * u as usize + 1]; - let factor_ip = meta[4 * u as usize + 2]; - let factor_err = meta[4 * u as usize + 3]; - [dis_u_2, factor_ppc, factor_ip, factor_err] - } - - fn preprocess( - trans_vector: &[f32], - dis_v_2: f32, - ) -> (Self::QvectorParams, Self::QvectorLookup) { - use quantization::quantize; - let (k, b, qvector) = quantize::quantize::<15>(trans_vector); - let qvector_sum = if trans_vector.len() <= 4369 { - quantize::reduce_sum_of_x_as_u16(&qvector) as f32 - } else { - quantize::reduce_sum_of_x_as_u32(&qvector) as f32 - }; - - let blut = binarize(&qvector); - let lut = gen(qvector); - ((dis_v_2, b, k, qvector_sum), (blut, lut)) - } - - fn process( - vector_params: &Self::VectorParams, - qvector_code: &[u8], - qvector_params: &Self::QvectorParams, - qvector_lookup: &Self::QvectorLookup, - ) -> Distance { - let (blut, _) = qvector_lookup; - let binary_prod = asymmetric_binary_dot_product(qvector_code, blut) as u16; - let (dis_u_2, factor_ppc, factor_ip, factor_err) = match vector_params { - [a, b, c, d] => (*a, *b, *c, *d), - }; - let (rough, _) = rabitq_l2( - dis_u_2, - factor_ppc, - factor_ip, - factor_err, - *qvector_params, - binary_prod, - ); - Distance::from_f32(rough) - } - - fn process_lowerbound( - vector_params: &Self::VectorParams, - qvector_code: &[u8], - qvector_params: &Self::QvectorParams, - qvector_lookup: &Self::QvectorLookup, - epsilon: f32, - ) -> Distance { - let (blut, _) = qvector_lookup; - let binary_prod = asymmetric_binary_dot_product(qvector_code, blut) as u16; - let (dis_u_2, factor_ppc, factor_ip, factor_err) = match vector_params { - [a, b, c, d] => (*a, *b, *c, *d), - }; - let (rough, err) = rabitq_l2( - dis_u_2, - factor_ppc, - factor_ip, - factor_err, - *qvector_params, - binary_prod, - ); - Distance::from_f32(rough - epsilon * err) - } - - fn fscan_preprocess(trans_vector: &[f32], dis_v_2: f32) -> (Self::QvectorParams, Vec) { - use quantization::quantize; - let (k, b, qvector) = quantize::quantize::<15>(trans_vector); - let qvector_sum = if trans_vector.len() <= 4369 { - quantize::reduce_sum_of_x_as_u16(&qvector) as f32 - } else { - quantize::reduce_sum_of_x_as_u32(&qvector) as f32 - }; - let lut = gen(qvector); - ((dis_v_2, b, k, qvector_sum), lut) - } - - fn fscan_process_lowerbound( - vector_params: &Self::VectorParams, - qvector_params: &Self::QvectorParams, - binary_prod: u16, - epsilon: f32, - ) -> Distance { - let (dis_u_2, factor_ppc, factor_ip, factor_err) = match vector_params { - [a, b, c, d] => (*a, *b, *c, *d), - }; - let (rough, err) = rabitq_l2( - dis_u_2, - factor_ppc, - factor_ip, - factor_err, - *qvector_params, - binary_prod, - ); - Distance::from_f32(rough - epsilon * err) - } -} - -impl OperatorRabitq for VectDot { - fn sample(vectors: &impl Vectors, nlist: u32) -> Vec2 { - VectL2::::sample(vectors, nlist) - } - fn cast(vector: Borrowed<'_, Self>) -> &[f32] { - VectL2::::cast(vector) - } - const SUPPORT_RESIDUAL: bool = false; - fn residual(_lhs: &[f32], _rhs: &[f32]) -> Vec { - unimplemented!() - } - fn proj(projection: &[Vec], vector: &[f32]) -> Vec { - VectL2::::proj(projection, vector) - } - - // [factor_ppc, factor_ip, factor_err] - type VectorParams = [f32; 3]; - // (dis_v_2, b, k, qvector_sum) - type QvectorParams = (f32, f32, f32, f32); - type QvectorLookup = ((Vec, Vec, Vec, Vec), Vec); - - fn train_encode(dims: u32, vector: Vec) -> Self::VectorParams { - let (factor_ppc, factor_ip, factor_err) = match VectL2::::train_encode(dims, vector) { - [_, b, c, d] => (b, c, d), - }; - - [factor_ppc, factor_ip, factor_err] - } - - fn train_decode + ?Sized>( - u: u32, - meta: &T, - ) -> Self::VectorParams { - let factor_ppc = meta[4 * u as usize + 0]; - let factor_ip = meta[4 * u as usize + 1]; - let factor_err = meta[4 * u as usize + 2]; - [factor_ppc, factor_ip, factor_err] - } - - fn preprocess( - trans_vector: &[f32], - dis_v_2: f32, - ) -> (Self::QvectorParams, Self::QvectorLookup) { - VectL2::::preprocess(trans_vector, dis_v_2) - } - - fn process( - vector_params: &Self::VectorParams, - qvector_code: &[u8], - qvector_params: &Self::QvectorParams, - qvector_lookup: &Self::QvectorLookup, - ) -> Distance { - let (blut, _) = qvector_lookup; - let binary_prod = asymmetric_binary_dot_product(qvector_code, blut) as u16; - let (factor_ppc, factor_ip, factor_err) = match vector_params { - [a, b, c] => (*a, *b, *c), - }; - let (rough, _) = rabitq_dot( - factor_ppc, - factor_ip, - factor_err, - *qvector_params, - binary_prod, - ); - Distance::from_f32(rough) - } - - fn process_lowerbound( - vector_params: &Self::VectorParams, - qvector_code: &[u8], - qvector_params: &Self::QvectorParams, - qvector_lookup: &Self::QvectorLookup, - epsilon: f32, - ) -> Distance { - let (blut, _) = qvector_lookup; - let binary_prod = asymmetric_binary_dot_product(qvector_code, blut) as u16; - let (factor_ppc, factor_ip, factor_err) = match vector_params { - [a, b, c] => (*a, *b, *c), - }; - let (rough, err) = rabitq_dot( - factor_ppc, - factor_ip, - factor_err, - *qvector_params, - binary_prod, - ); - Distance::from_f32(rough - epsilon * err) - } - fn fscan_preprocess(trans_vector: &[f32], dis_v_2: f32) -> (Self::QvectorParams, Vec) { - VectL2::::fscan_preprocess(trans_vector, dis_v_2) - } - - fn fscan_process_lowerbound( - vector_params: &Self::VectorParams, - qvector_params: &Self::QvectorParams, - binary_prod: u16, - epsilon: f32, - ) -> Distance { - let (factor_ppc, factor_ip, factor_err) = match vector_params { - [a, b, c] => (*a, *b, *c), - }; - let (rough, err) = rabitq_dot( - factor_ppc, - factor_ip, - factor_err, - *qvector_params, - binary_prod, - ); - Distance::from_f32(rough - epsilon * err) - } -} - -macro_rules! unimpl_operator_rabitq { - ($t:ty) => { - impl OperatorRabitq for $t { - fn sample(_: &impl Vectors, _: u32) -> Vec2 { - unimplemented!() - } - - fn cast(_: Borrowed<'_, Self>) -> &[f32] { - unimplemented!() - } - - const SUPPORT_RESIDUAL: bool = false; - fn residual(_: &[f32], _: &[f32]) -> Vec { - unimplemented!() - } - - fn proj(_: &[Vec], _: &[f32]) -> Vec { - unimplemented!() - } - - type VectorParams = [f32; 0]; - type QvectorParams = std::convert::Infallible; - type QvectorLookup = std::convert::Infallible; - - fn train_encode(_: u32, _: Vec) -> Self::VectorParams { - unimplemented!() - } - - fn train_decode + ?Sized>( - _: u32, - _: &T, - ) -> Self::VectorParams { - unimplemented!() - } - - fn preprocess(_: &[f32], _: f32) -> (Self::QvectorParams, Self::QvectorLookup) { - unimplemented!() - } - - fn process( - _: &Self::VectorParams, - _: &[u8], - _: &Self::QvectorParams, - _: &Self::QvectorLookup, - ) -> Distance { - unimplemented!() - } - - fn process_lowerbound( - _: &Self::VectorParams, - _: &[u8], - _: &Self::QvectorParams, - _: &Self::QvectorLookup, - _: f32, - ) -> Distance { - unimplemented!() - } - - fn fscan_preprocess(_: &[f32], _: f32) -> (Self::QvectorLookup, Vec) { - unimplemented!() - } - fn fscan_process_lowerbound( - _: &Self::VectorParams, - _: &Self::QvectorParams, - _: u16, - _: f32, - ) -> Distance { - unimplemented!() - } - } - }; -} - -unimpl_operator_rabitq!(VectDot); -unimpl_operator_rabitq!(VectL2); - -unimpl_operator_rabitq!(BVectorDot); -unimpl_operator_rabitq!(BVectorHamming); -unimpl_operator_rabitq!(BVectorJaccard); - -unimpl_operator_rabitq!(SVectDot); -unimpl_operator_rabitq!(SVectL2); - -#[inline(always)] -pub fn rabitq_l2( - dis_u_2: f32, - factor_ppc: f32, - factor_ip: f32, - factor_err: f32, - (dis_v_2, b, k, qvector_sum): (f32, f32, f32, f32), - binary_prod: u16, -) -> (f32, f32) { - let rough = dis_u_2 - + dis_v_2 - + b * factor_ppc - + ((2.0 * binary_prod as f32) - qvector_sum) * factor_ip * k; - let err = factor_err * dis_v_2.sqrt(); - (rough, err) -} - -#[inline(always)] -pub fn rabitq_dot( - factor_ppc: f32, - factor_ip: f32, - factor_err: f32, - (dis_v_2, b, k, qvector_sum): (f32, f32, f32, f32), - binary_prod: u16, -) -> (f32, f32) { - let rough = - 0.5 * b * factor_ppc + 0.5 * ((2.0 * binary_prod as f32) - qvector_sum) * factor_ip * k; - let err = factor_err * dis_v_2.sqrt() * 0.5; - (rough, err) -} - -fn binarize(vector: &[u8]) -> (Vec, Vec, Vec, Vec) { - let n = vector.len(); - let t0 = { - let mut t = vec![0u8; n.div_ceil(8)]; - for i in 0..n { - t[i / 8] |= ((vector[i] >> 0) & 1) << (i % 8); - } - t - }; - let t1 = { - let mut t = vec![0u8; n.div_ceil(8)]; - for i in 0..n { - t[i / 8] |= ((vector[i] >> 1) & 1) << (i % 8); - } - t - }; - let t2 = { - let mut t = vec![0u8; n.div_ceil(8)]; - for i in 0..n { - t[i / 8] |= ((vector[i] >> 2) & 1) << (i % 8); - } - t - }; - let t3 = { - let mut t = vec![0u8; n.div_ceil(8)]; - for i in 0..n { - t[i / 8] |= ((vector[i] >> 3) & 1) << (i % 8); - } - t - }; - (t0, t1, t2, t3) -} - -fn gen(mut qvector: Vec) -> Vec { - let dims = qvector.len() as u32; - let t = dims.div_ceil(4); - qvector.resize(qvector.len().next_multiple_of(4), 0); - let mut lut = vec![0u8; t as usize * 16]; - for i in 0..t as usize { - unsafe { - // this hint is used to skip bound checks - std::hint::assert_unchecked(4 * i + 3 < qvector.len()); - std::hint::assert_unchecked(16 * i + 15 < lut.len()); - } - let t0 = qvector[4 * i + 0]; - let t1 = qvector[4 * i + 1]; - let t2 = qvector[4 * i + 2]; - let t3 = qvector[4 * i + 3]; - lut[16 * i + 0b0000] = 0; - lut[16 * i + 0b0001] = t0; - lut[16 * i + 0b0010] = t1; - lut[16 * i + 0b0011] = t1 + t0; - lut[16 * i + 0b0100] = t2; - lut[16 * i + 0b0101] = t2 + t0; - lut[16 * i + 0b0110] = t2 + t1; - lut[16 * i + 0b0111] = t2 + t1 + t0; - lut[16 * i + 0b1000] = t3; - lut[16 * i + 0b1001] = t3 + t0; - lut[16 * i + 0b1010] = t3 + t1; - lut[16 * i + 0b1011] = t3 + t1 + t0; - lut[16 * i + 0b1100] = t3 + t2; - lut[16 * i + 0b1101] = t3 + t2 + t0; - lut[16 * i + 0b1110] = t3 + t2 + t1; - lut[16 * i + 0b1111] = t3 + t2 + t1 + t0; - } - lut -} - -fn binary_dot_product(x: &[u8], y: &[u8]) -> u32 { - assert_eq!(x.len(), y.len()); - let n = x.len(); - let mut res = 0; - for i in 0..n { - res += (x[i] & y[i]).count_ones(); - } - res -} - -fn asymmetric_binary_dot_product(x: &[u8], y: &(Vec, Vec, Vec, Vec)) -> u32 { - let mut res = 0; - res += binary_dot_product(x, &y.0) << 0; - res += binary_dot_product(x, &y.1) << 1; - res += binary_dot_product(x, &y.2) << 2; - res += binary_dot_product(x, &y.3) << 3; - res -} - -#[cfg(test)] -mod test { - use super::*; - use common::mmap_array::MmapArray; - use quantization::utils::InfiniteByteChunks; - use rand::{thread_rng, Rng}; - use std::{env, sync::LazyLock}; - - const EPSILON: f32 = 2.9; - const LENGTH: usize = 128; - const ATTEMPTS: usize = 10000; - - struct Case { - original: Vec, - centroid: Vec, - trans_vector: Vec, - } - - static PREPROCESS_O: LazyLock = LazyLock::new(|| { - let original: Vec = [(); LENGTH] - .into_iter() - .map(|_| thread_rng().gen_range((-1.0 * LENGTH as f32)..(LENGTH as f32))) - .collect(); - let centroid: Vec = vec![0.0; LENGTH].into_iter().collect(); - Case { - original: original.clone(), - centroid: centroid.clone(), - trans_vector: VectL2::::residual(&original, ¢roid), - } - }); - - #[test] - fn vector_f32l2_encode_decode() { - let path = env::temp_dir().join("meta_l2"); - let _ = std::fs::remove_file(path.clone()); - let case = &*PREPROCESS_O; - - let meta = - VectL2::::train_encode(case.trans_vector.len() as u32, case.trans_vector.clone()); - let mmap = MmapArray::create(path.clone(), Box::new(meta.into_iter())); - let params = VectL2::::train_decode(0, &mmap); - assert_eq!( - meta, params, - "Vecf32L2 encode and decode failed {:?} != {:?}", - meta, params - ); - std::fs::remove_file(path.clone()).unwrap(); - } - - #[test] - fn vector_f32dot_encode_decode() { - let path = env::temp_dir().join("meta_dot"); - let _ = std::fs::remove_file(path.clone()); - let case = &*PREPROCESS_O; - - let meta = - VectDot::::train_encode(case.trans_vector.len() as u32, case.trans_vector.clone()); - let mmap = MmapArray::create(path.clone(), Box::new(meta.into_iter())); - let params = VectDot::::train_decode(0, &mmap); - assert_eq!( - meta, params, - "Vecf32Dot encode and decode failed {:?} != {:?}", - meta, params - ); - std::fs::remove_file(path.clone()).unwrap(); - } - - #[test] - fn vector_f32l2_no_residual_estimate() { - let mut bad: usize = 0; - let case = &*PREPROCESS_O; - for _ in 0..ATTEMPTS { - let (query, trans_vector, dis_v_2, codes, estimate_failed) = - estimate_prepare_query(&case.centroid); - - let vector_params = VectL2::::train_encode( - case.trans_vector.len() as u32, - case.trans_vector.clone(), - ); - let (qvector_params, qvector_lookup) = - VectL2::::preprocess(&trans_vector, dis_v_2); - let est = - VectL2::::process(&vector_params, &codes, &qvector_params, &qvector_lookup); - let b = VectL2::::process_lowerbound( - &vector_params, - &codes, - &qvector_params, - &qvector_lookup, - EPSILON, - ); - - let real = f32::reduce_sum_of_d2(&query, &case.original); - if estimate_failed(est.to_f32(), b.to_f32(), real) { - bad += 1; - } - } - let error_rate = (bad as f32) / (ATTEMPTS as f32); - assert!( - error_rate < 0.02, - "too many errors: {} in {}", - bad, - ATTEMPTS, - ); - } - - #[test] - fn vector_f32dot_no_residual_estimate() { - let mut bad: usize = 0; - let case = &*PREPROCESS_O; - for _ in 0..ATTEMPTS { - let (query, trans_vector, dis_v_2, codes, estimate_failed) = - estimate_prepare_query(&case.centroid); - - let vector_params = VectDot::::train_encode( - case.trans_vector.len() as u32, - case.trans_vector.clone(), - ); - let (qvector_params, qvector_lookup) = - VectDot::::preprocess(&trans_vector, dis_v_2); - let est = - VectDot::::process(&vector_params, &codes, &qvector_params, &qvector_lookup); - let b = VectDot::::process_lowerbound( - &vector_params, - &codes, - &qvector_params, - &qvector_lookup, - EPSILON, - ); - - let real = -f32::reduce_sum_of_xy(&query, &case.original); - if estimate_failed(est.to_f32(), b.to_f32(), real) { - bad += 1; - } - } - let error_rate = (bad as f32) / (ATTEMPTS as f32); - assert!( - error_rate < 0.02, - "too many errors: {} in {}", - bad, - ATTEMPTS, - ); - } - - fn estimate_prepare_query( - centroid: &Vec, - ) -> ( - Vec, - Vec, - f32, - Vec, - impl Fn(f32, f32, f32) -> bool, - ) { - fn merge_8([b0, b1, b2, b3, b4, b5, b6, b7]: [u8; 8]) -> u8 { - b0 | (b1 << 1) | (b2 << 2) | (b3 << 3) | (b4 << 4) | (b5 << 5) | (b6 << 6) | (b7 << 7) - } - let query: Vec = [(); LENGTH] - .into_iter() - .map(|_| thread_rng().gen_range((-1.0 * LENGTH as f32)..(LENGTH as f32))) - .collect(); - let trans_vector = VectL2::::residual(&query, centroid); - let dis_v_2 = f32::reduce_sum_of_xy(&query, centroid); - let codes = - InfiniteByteChunks::new(trans_vector.iter().map(|e| e.is_sign_positive() as u8)) - .map(merge_8) - .take(trans_vector.len().div_ceil(8)) - .collect(); - fn estimate_failed(est: f32, b: f32, real: f32) -> bool { - let upper_bound = 2.0 * est - b; - b <= real && upper_bound >= real - } - (query, trans_vector, dis_v_2, codes, estimate_failed) - } -} diff --git a/crates/rabitq/src/quant/mod.rs b/crates/rabitq/src/quant/mod.rs deleted file mode 100644 index 73cc8aa79..000000000 --- a/crates/rabitq/src/quant/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod error; -pub mod quantization; -pub mod quantizer; diff --git a/crates/rabitq/src/quant/quantization.rs b/crates/rabitq/src/quant/quantization.rs deleted file mode 100644 index a9e829e2d..000000000 --- a/crates/rabitq/src/quant/quantization.rs +++ /dev/null @@ -1,192 +0,0 @@ -use super::quantizer::{Qvector, RabitqQuantizer}; -use crate::operator::OperatorRabitq; -use base::always_equal::AlwaysEqual; -use base::distance::Distance; -use base::index::VectorOptions; -use base::search::RerankerPop; -use common::json::Json; -use common::mmap_array::MmapArray; -use quantization::utils::InfiniteByteChunks; -use serde::{Deserialize, Serialize}; -use std::cmp::Reverse; -use std::ops::Range; -use std::path::Path; - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "")] -pub enum Quantizer { - Rabitq(RabitqQuantizer), -} - -impl Quantizer { - pub fn train(vector_options: VectorOptions) -> Self { - Self::Rabitq(RabitqQuantizer::train(vector_options)) - } -} - -pub enum RabitqPreprocessed { - Rabitq(Qvector), -} - -pub struct Quantization { - train: Json>, - codes: MmapArray, - packed_codes: MmapArray, - meta: MmapArray, -} - -impl Quantization { - pub fn create( - path: impl AsRef, - vector_options: VectorOptions, - n: u32, - vector_fetch: impl Fn(u32) -> Vec, - ) -> Self { - std::fs::create_dir(path.as_ref()).unwrap(); - fn merge_8([b0, b1, b2, b3, b4, b5, b6, b7]: [u8; 8]) -> u8 { - b0 | (b1 << 1) | (b2 << 2) | (b3 << 3) | (b4 << 4) | (b5 << 5) | (b6 << 6) | (b7 << 7) - } - fn merge_4([b0, b1, b2, b3]: [u8; 4]) -> u8 { - b0 | (b1 << 2) | (b2 << 4) | (b3 << 6) - } - fn merge_2([b0, b1]: [u8; 2]) -> u8 { - b0 | (b1 << 4) - } - let train = Quantizer::train(vector_options); - let train = Json::create(path.as_ref().join("train"), train); - let codes = MmapArray::create(path.as_ref().join("codes"), { - match &*train { - Quantizer::Rabitq(x) => Box::new((0..n).flat_map(|i| { - let vector = vector_fetch(i); - let codes = x.encode(&vector); - let bytes = x.bytes(); - match x.bits() { - 1 => InfiniteByteChunks::new(codes.into_iter()) - .map(merge_8) - .take(bytes as usize) - .collect(), - 2 => InfiniteByteChunks::new(codes.into_iter()) - .map(merge_4) - .take(bytes as usize) - .collect(), - 4 => InfiniteByteChunks::new(codes.into_iter()) - .map(merge_2) - .take(bytes as usize) - .collect(), - 8 => codes, - _ => unreachable!(), - } - })), - } - }); - let packed_codes = MmapArray::create( - path.as_ref().join("packed_codes"), - match &*train { - Quantizer::Rabitq(x) => { - use quantization::fast_scan::b4::{pack, BLOCK_SIZE}; - let blocks = n.div_ceil(BLOCK_SIZE); - Box::new((0..blocks).flat_map(|block| { - let t = x.dims().div_ceil(4); - let raw = std::array::from_fn::<_, { BLOCK_SIZE as _ }, _>(|i| { - let id = BLOCK_SIZE * block + i as u32; - let vector = vector_fetch(std::cmp::min(id, n - 1)); - let codes = x.encode(&vector); - InfiniteByteChunks::new(codes.into_iter()) - .map(|[b0, b1, b2, b3]| b0 | b1 << 1 | b2 << 2 | b3 << 3) - .take(t as usize) - .collect() - }); - pack(t, raw) - })) as Box> - } - }, - ); - let meta = MmapArray::create( - path.as_ref().join("meta"), - match &*train { - Quantizer::Rabitq(x) => Box::new((0..n).flat_map(|i| { - let vector = vector_fetch(i); - O::train_encode(x.dims(), vector).into_iter() - })), - }, - ); - Self { - train, - codes, - packed_codes, - meta, - } - } - - pub fn open(path: impl AsRef) -> Self { - let train = Json::open(path.as_ref().join("train")); - let codes = MmapArray::open(path.as_ref().join("codes")); - let packed_codes = MmapArray::open(path.as_ref().join("packed_codes")); - let meta = MmapArray::open(path.as_ref().join("meta")); - Self { - train, - codes, - packed_codes, - meta, - } - } - - pub fn preprocess(&self, trans_vector: &[f32], dis_v_2: f32) -> RabitqPreprocessed { - let (params, blut) = match &*self.train { - Quantizer::Rabitq(x) => x.preprocess(trans_vector, dis_v_2), - }; - RabitqPreprocessed::Rabitq(Qvector::Scan((params, blut))) - } - - pub fn fscan_preprocess(&self, trans_vector: &[f32], dis_v_2: f32) -> RabitqPreprocessed { - let (params, lut) = match &*self.train { - Quantizer::Rabitq(x) => x.fscan_preprocess(trans_vector, dis_v_2), - }; - RabitqPreprocessed::Rabitq(Qvector::FastScan((params, lut))) - } - - pub fn process(&self, preprocessed: &RabitqPreprocessed, u: u32) -> Distance { - match (&*self.train, preprocessed) { - (Quantizer::Rabitq(x), RabitqPreprocessed::Rabitq(Qvector::Scan((params, blut)))) => { - let bytes = x.bytes() as usize; - let start = u as usize * bytes; - let end = start + bytes; - let vector_params = O::train_decode(u, &self.meta); - let code = &self.codes[start..end]; - x.process(&vector_params, params, blut, code) - } - _ => unreachable!(), - } - } - - pub fn push_batch( - &self, - preprocessed: &RabitqPreprocessed, - range: Range, - heap: &mut Vec<(Reverse, AlwaysEqual)>, - rq_epsilon: f32, - ) { - match (&*self.train, preprocessed) { - (Quantizer::Rabitq(x), RabitqPreprocessed::Rabitq(qvector)) => x.push_batch( - qvector, - range, - heap, - &self.codes, - &self.packed_codes, - &self.meta, - rq_epsilon, - ), - } - } - - pub fn rerank<'a, T: 'a>( - &'a self, - heap: Vec<(Reverse, AlwaysEqual)>, - r: impl Fn(u32) -> (Distance, T) + 'a, - ) -> impl RerankerPop + 'a { - use Quantizer::*; - match &*self.train { - Rabitq(x) => x.rerank(heap, r), - } - } -} diff --git a/crates/rabitq/src/quant/quantizer.rs b/crates/rabitq/src/quant/quantizer.rs deleted file mode 100644 index 66819895b..000000000 --- a/crates/rabitq/src/quant/quantizer.rs +++ /dev/null @@ -1,247 +0,0 @@ -use super::error::ErrorFlatReranker; -use crate::operator::OperatorRabitq; -use base::always_equal::AlwaysEqual; -use base::distance::Distance; -use base::index::VectorOptions; -use base::search::RerankerPop; -use serde::{Deserialize, Serialize}; -use std::cmp::Reverse; -use std::marker::PhantomData; -use std::ops::Range; - -pub enum Qvector { - FastScan((O::QvectorParams, Vec)), - Scan((O::QvectorParams, O::QvectorLookup)), -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "")] -pub struct RabitqQuantizer { - dims: u32, - _maker: PhantomData O>, -} - -impl RabitqQuantizer { - pub fn train(vector_options: VectorOptions) -> Self { - let dims = vector_options.dims; - Self { - dims, - _maker: PhantomData, - } - } - - pub fn bits(&self) -> u32 { - 1 - } - - pub fn bytes(&self) -> u32 { - self.dims.div_ceil(8) - } - - pub fn dims(&self) -> u32 { - self.dims - } - - pub fn width(&self) -> u32 { - self.dims - } - - pub fn encode(&self, vector: &[f32]) -> Vec { - let mut codes = Vec::new(); - for i in 0..self.dims { - codes.push(vector[i as usize].is_sign_positive() as u8); - } - codes - } - - pub fn preprocess( - &self, - trans_vector: &[f32], - dis_v_2: f32, - ) -> (O::QvectorParams, O::QvectorLookup) { - O::preprocess(trans_vector, dis_v_2) - } - - pub fn fscan_preprocess( - &self, - trans_vector: &[f32], - dis_v_2: f32, - ) -> (O::QvectorParams, Vec) { - O::fscan_preprocess(trans_vector, dis_v_2) - } - - pub fn process( - &self, - vector_params: &O::VectorParams, - qvector_params: &O::QvectorParams, - qvector_lookup: &O::QvectorLookup, - qvector_code: &[u8], - ) -> Distance { - O::process(vector_params, qvector_code, qvector_params, qvector_lookup) - } - - pub fn process_lowerbound( - &self, - vector_params: &O::VectorParams, - qvector_params: &O::QvectorParams, - qvector_lookup: &O::QvectorLookup, - qvector_code: &[u8], - epsilon: f32, - ) -> Distance { - O::process_lowerbound( - vector_params, - qvector_code, - qvector_params, - qvector_lookup, - epsilon, - ) - } - - pub fn push_batch( - &self, - qvector: &Qvector, - range: Range, - heap: &mut Vec<(Reverse, AlwaysEqual)>, - codes: &[u8], - packed_codes: &[u8], - meta: &[f32], - epsilon: f32, - ) { - match qvector { - Qvector::FastScan((params, lut)) => { - self.push_back_fscan(params, lut, range, heap, packed_codes, meta, epsilon); - } - Qvector::Scan((params, blut)) => { - self.push_back_scan(params, blut, range, heap, codes, meta, epsilon); - } - } - } - - #[inline] - fn push_back_fscan( - &self, - qvector_params: &O::QvectorParams, - lut: &[u8], - rhs: Range, - heap: &mut Vec<(Reverse, AlwaysEqual)>, - packed_codes: &[u8], - meta: &[f32], - epsilon: f32, - ) { - use quantization::fast_scan::b4::{fast_scan_b4, BLOCK_SIZE}; - let s = rhs.start.next_multiple_of(BLOCK_SIZE); - let e = (rhs.end + 1 - BLOCK_SIZE).next_multiple_of(BLOCK_SIZE); - if rhs.start != s { - let i = s - BLOCK_SIZE; - let t = self.dims.div_ceil(4); - let bytes = (t * 16) as usize; - let start = (i / BLOCK_SIZE) as usize * bytes; - let end = start + bytes; - let all_binary_product = fast_scan_b4(t, &packed_codes[start..end], lut); - heap.extend({ - (rhs.start..s).map(|u| { - ( - Reverse({ - let params = &O::train_decode(u, meta); - let binary_prod = all_binary_product[(u - i) as usize]; - O::fscan_process_lowerbound( - params, - qvector_params, - binary_prod, - epsilon, - ) - }), - AlwaysEqual(u), - ) - }) - }); - } - for i in (s..e).step_by(BLOCK_SIZE as _) { - let t = self.dims.div_ceil(4); - let bytes = (t * 16) as usize; - let start = (i / BLOCK_SIZE) as usize * bytes; - let end = start + bytes; - let all_binary_product = fast_scan_b4(t, &packed_codes[start..end], lut); - heap.extend({ - (i..i + BLOCK_SIZE).map(|u| { - ( - Reverse({ - let params = &O::train_decode(u, meta); - let binary_prod = all_binary_product[(u - i) as usize]; - O::fscan_process_lowerbound( - params, - qvector_params, - binary_prod, - epsilon, - ) - }), - AlwaysEqual(u), - ) - }) - }); - } - if e != rhs.end { - let i = e; - let t = self.dims.div_ceil(4); - let bytes = (t * 16) as usize; - let start = (i / BLOCK_SIZE) as usize * bytes; - let end = start + bytes; - let all_binary_product = fast_scan_b4(t, &packed_codes[start..end], lut); - heap.extend({ - (e..rhs.end).map(|u| { - ( - Reverse({ - let params = &O::train_decode(u, meta); - let binary_prod = all_binary_product[(u - i) as usize]; - O::fscan_process_lowerbound( - params, - qvector_params, - binary_prod, - epsilon, - ) - }), - AlwaysEqual(u), - ) - }) - }); - } - } - - #[inline] - fn push_back_scan( - &self, - qvector_params: &O::QvectorParams, - qvector_lookup: &O::QvectorLookup, - rhs: Range, - heap: &mut Vec<(Reverse, AlwaysEqual)>, - codes: &[u8], - meta: &[f32], - epsilon: f32, - ) { - heap.extend(rhs.map(|u| { - ( - Reverse(self.process_lowerbound( - &O::train_decode(u, meta), - qvector_params, - qvector_lookup, - { - let bytes = self.bytes() as usize; - let start = u as usize * bytes; - let end = start + bytes; - &codes[start..end] - }, - epsilon, - )), - AlwaysEqual(u), - ) - })); - } - - pub fn rerank<'a, T: 'a>( - &'a self, - heap: Vec<(Reverse, AlwaysEqual)>, - rerank: impl Fn(u32) -> (Distance, T) + 'a, - ) -> impl RerankerPop + 'a { - ErrorFlatReranker::new(heap, rerank) - } -} diff --git a/crates/stoppable_rayon/src/lib.rs b/crates/stoppable_rayon/src/lib.rs index d2d6b38f5..482578594 100644 --- a/crates/stoppable_rayon/src/lib.rs +++ b/crates/stoppable_rayon/src/lib.rs @@ -4,26 +4,11 @@ use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; use std::sync::Arc; -pub use rayon::array; -pub use rayon::collections; -pub use rayon::iter; -pub use rayon::option; -pub use rayon::prelude; -pub use rayon::range; -pub use rayon::range_inclusive; -pub use rayon::result; -pub use rayon::slice; -pub use rayon::str; -pub use rayon::string; -pub use rayon::vec; - -pub use rayon::{current_num_threads, current_thread_index, max_num_threads}; -pub use rayon::{in_place_scope, in_place_scope_fifo}; -pub use rayon::{join, join_context}; -pub use rayon::{scope, scope_fifo}; -pub use rayon::{spawn, spawn_fifo}; -pub use rayon::{yield_local, yield_now}; -pub use rayon::{FnContext, Scope, ScopeFifo, Yield}; +pub mod iter { + pub use rayon::iter::IntoParallelIterator; + pub use rayon::iter::IntoParallelRefMutIterator; + pub use rayon::iter::ParallelIterator; +} #[derive(Debug, Default)] pub struct ThreadPoolBuilder { diff --git a/src/gucs/executing.rs b/src/gucs/executing.rs index d5d2ba465..4366514a5 100644 --- a/src/gucs/executing.rs +++ b/src/gucs/executing.rs @@ -13,24 +13,15 @@ static PQ_RERANK_SIZE: GucSetting = static PQ_FAST_SCAN: GucSetting = GucSetting::::new(SearchOptions::default_pq_fast_scan()); +static RQ_FAST_SCAN: GucSetting = + GucSetting::::new(SearchOptions::default_rq_fast_scan()); + static IVF_NPROBE: GucSetting = GucSetting::::new(SearchOptions::default_ivf_nprobe() as i32); static HNSW_EF_SEARCH: GucSetting = GucSetting::::new(SearchOptions::default_hnsw_ef_search() as i32); -static RABITQ_NPROBE: GucSetting = - GucSetting::::new(SearchOptions::default_rabitq_nprobe() as i32); - -static RABITQ_EPSILON: GucSetting = - GucSetting::::new(SearchOptions::default_rabitq_epsilon() as f64); - -static RABITQ_FAST_SCAN: GucSetting = - GucSetting::::new(SearchOptions::default_rabitq_fast_scan()); - -static DISKANN_EF_SEARCH: GucSetting = - GucSetting::::new(SearchOptions::default_diskann_ef_search() as i32); - pub unsafe fn init() { GucRegistry::define_int_guc( "vectors.sq_rerank_size", @@ -68,6 +59,14 @@ pub unsafe fn init() { GucContext::Userset, GucFlags::default(), ); + GucRegistry::define_bool_guc( + "vectors.rq_fast_scan", + "Enables fast scan or not.", + "https://docs.pgvecto.rs/usage/search.html", + &PQ_FAST_SCAN, + GucContext::Userset, + GucFlags::default(), + ); GucRegistry::define_int_guc( "vectors.ivf_nprobe", "`nprobe` argument of IVF algorithm.", @@ -88,44 +87,6 @@ pub unsafe fn init() { GucContext::Userset, GucFlags::default(), ); - GucRegistry::define_int_guc( - "vectors.rabitq_nprobe", - "`nprobe` argument of RaBitQ algorithm.", - "https://docs.pgvecto.rs/usage/search.html", - &RABITQ_NPROBE, - 1, - u16::MAX as _, - GucContext::Userset, - GucFlags::default(), - ); - GucRegistry::define_float_guc( - "vectors.rabitq_epsilon", - "`epsilon` argument of RaBitQ algorithm.", - "https://docs.pgvecto.rs/usage/search.html", - &RABITQ_EPSILON, - 1.0, - 4.0, - GucContext::Userset, - GucFlags::default(), - ); - GucRegistry::define_bool_guc( - "vectors.rabitq_fast_scan", - "Enables fast scan or not.", - "https://docs.pgvecto.rs/usage/search.html", - &RABITQ_FAST_SCAN, - GucContext::Userset, - GucFlags::default(), - ); - GucRegistry::define_int_guc( - "vectors.diskann_ef_search", - "`ef_search` argument of DiskANN algorithm.", - "https://docs.pgvecto.rs/usage/search.html", - &DISKANN_EF_SEARCH, - 1, - u16::MAX as _, - GucContext::Userset, - GucFlags::default(), - ); } pub fn search_options() -> SearchOptions { @@ -134,11 +95,8 @@ pub fn search_options() -> SearchOptions { sq_fast_scan: SQ_FAST_SCAN.get(), pq_rerank_size: PQ_RERANK_SIZE.get() as u32, pq_fast_scan: PQ_FAST_SCAN.get(), + rq_fast_scan: RQ_FAST_SCAN.get(), ivf_nprobe: IVF_NPROBE.get() as u32, hnsw_ef_search: HNSW_EF_SEARCH.get() as u32, - rabitq_nprobe: RABITQ_NPROBE.get() as u32, - rabitq_epsilon: RABITQ_EPSILON.get() as f32, - rabitq_fast_scan: RABITQ_FAST_SCAN.get(), - diskann_ef_search: DISKANN_EF_SEARCH.get() as u32, } }