diff --git a/crates/rabitq/src/lib.rs b/crates/rabitq/src/lib.rs index 67c03fb2a..60116ebc8 100644 --- a/crates/rabitq/src/lib.rs +++ b/crates/rabitq/src/lib.rs @@ -75,19 +75,25 @@ impl Rabitq { ); let mut heap = Vec::new(); for &(_, i) in lists.iter() { - let preprocessed = self.quantization.preprocess(&O::residual( - &projected_query, - &self.projected_centroids[(i,)], - )); + let preprocessed = if opts.rabitq_fast_scan { + self.quantization + .fscan_preprocess(&O::residual( + &projected_query, + &self.projected_centroids[(i,)], + )) + .into() + } else { + self.quantization + .preprocess(&O::residual( + &projected_query, + &self.projected_centroids[(i,)], + )) + .into() + }; let start = self.offsets[i]; let end = self.offsets[i + 1]; - self.quantization.push_batch( - &preprocessed, - start..end, - &mut heap, - opts.rabitq_epsilon, - opts.rabitq_fast_scan, - ); + self.quantization + .push_batch(&preprocessed, start..end, &mut heap, opts.rabitq_epsilon); } let mut reranker = self.quantization.rerank(heap, move |u| { (O::distance(vector, self.storage.vector(u)), ()) diff --git a/crates/rabitq/src/operator.rs b/crates/rabitq/src/operator.rs index 5c5cf55d4..8fad76b12 100644 --- a/crates/rabitq/src/operator.rs +++ b/crates/rabitq/src/operator.rs @@ -13,18 +13,19 @@ pub trait OperatorRabitq: OperatorStorage { fn residual(lhs: &[f32], rhs: &[f32]) -> Vec; fn proj(projection: &[Vec], vector: &[f32]) -> Vec; - type Preprocessed0; - type Preprocessed1; + type Params; - fn preprocess(vector: &[f32]) -> (Self::Preprocessed0, Self::Preprocessed1); + type Preprocessed; + + fn preprocess(vector: &[f32]) -> (Self::Params, Self::Preprocessed); fn process( dis_u_2: f32, factor_ppc: f32, factor_ip: f32, factor_err: f32, code: &[u8], - p0: &Self::Preprocessed0, - p1: &Self::Preprocessed1, + p0: &Self::Params, + p1: &Self::Preprocessed, ) -> Distance; fn process_lowerbound( dis_u_2: f32, @@ -32,17 +33,18 @@ pub trait OperatorRabitq: OperatorStorage { factor_ip: f32, factor_err: f32, code: &[u8], - p0: &Self::Preprocessed0, - p1: &Self::Preprocessed1, + p0: &Self::Params, + p1: &Self::Preprocessed, epsilon: f32, ) -> Distance; - fn fscan_preprocess(preprocessed: &Self::Preprocessed1) -> Vec; + + fn fscan_preprocess(vector: &[f32]) -> (Self::Params, Vec); fn fscan_process_lowerbound( dis_u_2: f32, factor_ppc: f32, factor_ip: f32, factor_err: f32, - p0: &Self::Preprocessed0, + p0: &Self::Params, param: u16, epsilon: f32, ) -> Distance; @@ -71,8 +73,9 @@ impl OperatorRabitq for VectL2 { .collect() } - type Preprocessed0 = (f32, f32, f32, f32); - type Preprocessed1 = ((Vec, Vec, Vec, Vec), Vec); + type Params = (f32, f32, f32, f32); + + type Preprocessed = ((Vec, Vec, Vec, Vec), Vec); fn preprocess( vector: &[f32], @@ -122,8 +125,17 @@ impl OperatorRabitq for VectL2 { Distance::from_f32(rough - epsilon * err) } - fn fscan_preprocess(preprocessed: &((Vec, Vec, Vec, Vec), Vec)) -> Vec { - preprocessed.1.clone() + fn fscan_preprocess(vector: &[f32]) -> ((f32, f32, f32, f32), Vec) { + use quantization::quantize; + let dis_v_2 = f32::reduce_sum_of_x2(vector); + let (k, b, qvector) = quantize::quantize::<15>(vector); + let qvector_sum = if vector.len() <= 4369 { + quantize::reduce_sum_of_x_as_u16(&qvector) as f32 + } else { + quantize::reduce_sum_of_x_as_u32(&qvector) as f32 + }; + let lut = gen(qvector); + ((dis_v_2, b, k, qvector_sum), lut) } fn fscan_process_lowerbound( @@ -131,7 +143,7 @@ impl OperatorRabitq for VectL2 { factor_ppc: f32, factor_ip: f32, factor_err: f32, - p0: &Self::Preprocessed0, + p0: &Self::Params, param: u16, epsilon: f32, ) -> Distance { @@ -159,10 +171,10 @@ macro_rules! unimpl_operator_rabitq { unimplemented!() } - type Preprocessed0 = std::convert::Infallible; - type Preprocessed1 = std::convert::Infallible; + type Params = std::convert::Infallible; + type Preprocessed = std::convert::Infallible; - fn preprocess(_: &[f32]) -> (Self::Preprocessed0, Self::Preprocessed1) { + fn preprocess(_: &[f32]) -> (Self::Params, Self::Preprocessed) { unimplemented!() } @@ -172,8 +184,8 @@ macro_rules! unimpl_operator_rabitq { _: f32, _: f32, _: &[u8], - _: &Self::Preprocessed0, - _: &Self::Preprocessed1, + _: &Self::Params, + _: &Self::Preprocessed, ) -> Distance { unimplemented!() } @@ -184,14 +196,14 @@ macro_rules! unimpl_operator_rabitq { _: f32, _: f32, _: &[u8], - _: &Self::Preprocessed0, - _: &Self::Preprocessed1, + _: &Self::Params, + _: &Self::Preprocessed, _: f32, ) -> Distance { unimplemented!() } - fn fscan_preprocess(_: &Self::Preprocessed1) -> Vec { + fn fscan_preprocess(_: &[f32]) -> (Self::Params, Vec) { unimplemented!() } @@ -200,7 +212,7 @@ macro_rules! unimpl_operator_rabitq { _: f32, _: f32, _: f32, - _: &Self::Preprocessed0, + _: &Self::Params, _: u16, _: f32, ) -> Distance { diff --git a/crates/rabitq/src/quant/quantization.rs b/crates/rabitq/src/quant/quantization.rs index 6922a4ec2..95cd161d0 100644 --- a/crates/rabitq/src/quant/quantization.rs +++ b/crates/rabitq/src/quant/quantization.rs @@ -27,8 +27,37 @@ impl Quantizer { pub enum QuantizationPreprocessed { Rabitq( ( - ::Preprocessed0, - ::Preprocessed1, + ::Params, + ::Preprocessed, + ), + ), +} + +impl From> for QuantizationAnyPreprocessed { + fn from(value: QuantizationPreprocessed) -> Self { + match value { + QuantizationPreprocessed::Rabitq((param, blut)) => Self::Rabitq((param, Ok(blut))), + } + } +} + +pub enum QuantizationFscanPreprocessed { + Rabitq((::Params, Vec)), +} + +impl From> for QuantizationAnyPreprocessed { + fn from(value: QuantizationFscanPreprocessed) -> Self { + match value { + QuantizationFscanPreprocessed::Rabitq((param, lut)) => Self::Rabitq((param, Err(lut))), + } + } +} + +pub enum QuantizationAnyPreprocessed { + Rabitq( + ( + ::Params, + Result<::Preprocessed, Vec>, ), ), } @@ -141,6 +170,12 @@ impl Quantization { } } + pub fn fscan_preprocess(&self, lhs: &[f32]) -> QuantizationFscanPreprocessed { + match &*self.train { + Quantizer::Rabitq(x) => QuantizationFscanPreprocessed::Rabitq(x.fscan_preprocess(lhs)), + } + } + pub fn process(&self, preprocessed: &QuantizationPreprocessed, u: u32) -> Distance { match (&*self.train, preprocessed) { (Quantizer::Rabitq(x), QuantizationPreprocessed::Rabitq(lhs)) => { @@ -159,22 +194,21 @@ impl Quantization { pub fn push_batch( &self, - preprocessed: &QuantizationPreprocessed, + preprocessed: &QuantizationAnyPreprocessed, rhs: Range, heap: &mut Vec<(Reverse, AlwaysEqual)>, rq_epsilon: f32, - rq_fast_scan: bool, ) { match (&*self.train, preprocessed) { - (Quantizer::Rabitq(x), QuantizationPreprocessed::Rabitq(lhs)) => x.push_batch( - lhs, + (Quantizer::Rabitq(x), QuantizationAnyPreprocessed::Rabitq((a, b))) => x.push_batch( + a, + b, rhs, heap, &self.codes, &self.packed_codes, &self.meta, rq_epsilon, - rq_fast_scan, ), } } diff --git a/crates/rabitq/src/quant/quantizer.rs b/crates/rabitq/src/quant/quantizer.rs index 3142cb691..9d5bac3ae 100644 --- a/crates/rabitq/src/quant/quantizer.rs +++ b/crates/rabitq/src/quant/quantizer.rs @@ -68,14 +68,18 @@ impl RabitqQuantizer { (dis_u * dis_u, factor_ppc, factor_ip, factor_err, codes) } - pub fn preprocess(&self, lhs: &[f32]) -> (O::Preprocessed0, O::Preprocessed1) { + pub fn preprocess(&self, lhs: &[f32]) -> (O::Params, O::Preprocessed) { O::preprocess(lhs) } + pub fn fscan_preprocess(&self, lhs: &[f32]) -> (O::Params, Vec) { + O::fscan_preprocess(lhs) + } + pub fn process( &self, - p0: &O::Preprocessed0, - p1: &O::Preprocessed1, + p0: &O::Params, + p1: &O::Preprocessed, (a, b, c, d, e): (f32, f32, f32, f32, &[u8]), ) -> Distance { O::process(a, b, c, d, e, p0, p1) @@ -83,8 +87,8 @@ impl RabitqQuantizer { pub fn process_lowerbound( &self, - p0: &O::Preprocessed0, - p1: &O::Preprocessed1, + p0: &O::Params, + p1: &O::Preprocessed, (a, b, c, d, e): (f32, f32, f32, f32, &[u8]), epsilon: f32, ) -> Distance { @@ -93,110 +97,112 @@ impl RabitqQuantizer { pub fn push_batch( &self, - (p0, p1): &(O::Preprocessed0, O::Preprocessed1), + alpha: &O::Params, + beta: &Result>, rhs: Range, heap: &mut Vec<(Reverse, AlwaysEqual)>, codes: &[u8], packed_codes: &[u8], meta: &[f32], epsilon: f32, - fast_scan: bool, ) { - if fast_scan { - use quantization::fast_scan::b4::{fast_scan_b4, BLOCK_SIZE}; - let s = rhs.start.next_multiple_of(BLOCK_SIZE); - let e = (rhs.end + 1 - BLOCK_SIZE).next_multiple_of(BLOCK_SIZE); - let lut = O::fscan_preprocess(p1); - if rhs.start != s { - let i = s - BLOCK_SIZE; - let t = self.dims.div_ceil(4); - let bytes = (t * 16) as usize; - let start = (i / BLOCK_SIZE) as usize * bytes; - let end = start + bytes; - let res = fast_scan_b4(t, &packed_codes[start..end], &lut); - heap.extend({ - (rhs.start..s).map(|u| { - ( - Reverse({ - let a = meta[4 * u as usize + 0]; - let b = meta[4 * u as usize + 1]; - let c = meta[4 * u as usize + 2]; - let d = meta[4 * u as usize + 3]; - let param = res[(u - i) as usize]; - O::fscan_process_lowerbound(a, b, c, d, p0, param, epsilon) - }), - AlwaysEqual(u), - ) - }) - }); - } - for i in (s..e).step_by(BLOCK_SIZE as _) { - let t = self.dims.div_ceil(4); - let bytes = (t * 16) as usize; - let start = (i / BLOCK_SIZE) as usize * bytes; - let end = start + bytes; - let res = fast_scan_b4(t, &packed_codes[start..end], &lut); - heap.extend({ - (i..i + BLOCK_SIZE).map(|u| { - ( - Reverse({ - let a = meta[4 * u as usize + 0]; - let b = meta[4 * u as usize + 1]; - let c = meta[4 * u as usize + 2]; - let d = meta[4 * u as usize + 3]; - let param = res[(u - i) as usize]; - O::fscan_process_lowerbound(a, b, c, d, p0, param, epsilon) - }), - AlwaysEqual(u), - ) - }) - }); + match beta { + Err(lut) => { + use quantization::fast_scan::b4::{fast_scan_b4, BLOCK_SIZE}; + let s = rhs.start.next_multiple_of(BLOCK_SIZE); + let e = (rhs.end + 1 - BLOCK_SIZE).next_multiple_of(BLOCK_SIZE); + if rhs.start != s { + let i = s - BLOCK_SIZE; + let t = self.dims.div_ceil(4); + let bytes = (t * 16) as usize; + let start = (i / BLOCK_SIZE) as usize * bytes; + let end = start + bytes; + let res = fast_scan_b4(t, &packed_codes[start..end], lut); + heap.extend({ + (rhs.start..s).map(|u| { + ( + Reverse({ + let a = meta[4 * u as usize + 0]; + let b = meta[4 * u as usize + 1]; + let c = meta[4 * u as usize + 2]; + let d = meta[4 * u as usize + 3]; + let param = res[(u - i) as usize]; + O::fscan_process_lowerbound(a, b, c, d, alpha, param, epsilon) + }), + AlwaysEqual(u), + ) + }) + }); + } + for i in (s..e).step_by(BLOCK_SIZE as _) { + let t = self.dims.div_ceil(4); + let bytes = (t * 16) as usize; + let start = (i / BLOCK_SIZE) as usize * bytes; + let end = start + bytes; + let res = fast_scan_b4(t, &packed_codes[start..end], lut); + heap.extend({ + (i..i + BLOCK_SIZE).map(|u| { + ( + Reverse({ + let a = meta[4 * u as usize + 0]; + let b = meta[4 * u as usize + 1]; + let c = meta[4 * u as usize + 2]; + let d = meta[4 * u as usize + 3]; + let param = res[(u - i) as usize]; + O::fscan_process_lowerbound(a, b, c, d, alpha, param, epsilon) + }), + AlwaysEqual(u), + ) + }) + }); + } + if e != rhs.end { + let i = e; + let t = self.dims.div_ceil(4); + let bytes = (t * 16) as usize; + let start = (i / BLOCK_SIZE) as usize * bytes; + let end = start + bytes; + let res = fast_scan_b4(t, &packed_codes[start..end], lut); + heap.extend({ + (e..rhs.end).map(|u| { + ( + Reverse({ + let a = meta[4 * u as usize + 0]; + let b = meta[4 * u as usize + 1]; + let c = meta[4 * u as usize + 2]; + let d = meta[4 * u as usize + 3]; + let param = res[(u - i) as usize]; + O::fscan_process_lowerbound(a, b, c, d, alpha, param, epsilon) + }), + AlwaysEqual(u), + ) + }) + }); + } } - if e != rhs.end { - let i = e; - let t = self.dims.div_ceil(4); - let bytes = (t * 16) as usize; - let start = (i / BLOCK_SIZE) as usize * bytes; - let end = start + bytes; - let res = fast_scan_b4(t, &packed_codes[start..end], &lut); - heap.extend({ - (e..rhs.end).map(|u| { - ( - Reverse({ + Ok(blut) => { + heap.extend(rhs.map(|u| { + ( + Reverse(self.process_lowerbound( + alpha, + blut, + { + let bytes = self.bytes() as usize; + let start = u as usize * bytes; + let end = start + bytes; let a = meta[4 * u as usize + 0]; let b = meta[4 * u as usize + 1]; let c = meta[4 * u as usize + 2]; let d = meta[4 * u as usize + 3]; - let param = res[(u - i) as usize]; - O::fscan_process_lowerbound(a, b, c, d, p0, param, epsilon) - }), - AlwaysEqual(u), - ) - }) - }); + (a, b, c, d, &codes[start..end]) + }, + epsilon, + )), + AlwaysEqual(u), + ) + })); } - return; } - heap.extend(rhs.map(|u| { - ( - Reverse(self.process_lowerbound( - p0, - p1, - { - let bytes = self.bytes() as usize; - let start = u as usize * bytes; - let end = start + bytes; - let a = meta[4 * u as usize + 0]; - let b = meta[4 * u as usize + 1]; - let c = meta[4 * u as usize + 2]; - let d = meta[4 * u as usize + 3]; - (a, b, c, d, &codes[start..end]) - }, - epsilon, - )), - AlwaysEqual(u), - ) - })); } pub fn rerank<'a, T: 'a>(