Skip to content

Commit

Permalink
feat: do not generate slow-scan lookup table if fast-scan is enabled (#…
Browse files Browse the repository at this point in the history
…583)

Signed-off-by: usamoi <usamoi@outlook.com>
  • Loading branch information
usamoi authored Sep 2, 2024
1 parent 590a9d6 commit 3b7c694
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 136 deletions.
28 changes: 17 additions & 11 deletions crates/rabitq/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,19 +75,25 @@ impl<O: Op> Rabitq<O> {
);
let mut heap = Vec::new();
for &(_, i) in lists.iter() {
let preprocessed = self.quantization.preprocess(&O::residual(
&projected_query,
&self.projected_centroids[(i,)],
));
let preprocessed = if opts.rabitq_fast_scan {
self.quantization
.fscan_preprocess(&O::residual(
&projected_query,
&self.projected_centroids[(i,)],
))
.into()
} else {
self.quantization
.preprocess(&O::residual(
&projected_query,
&self.projected_centroids[(i,)],
))
.into()
};
let start = self.offsets[i];
let end = self.offsets[i + 1];
self.quantization.push_batch(
&preprocessed,
start..end,
&mut heap,
opts.rabitq_epsilon,
opts.rabitq_fast_scan,
);
self.quantization
.push_batch(&preprocessed, start..end, &mut heap, opts.rabitq_epsilon);
}
let mut reranker = self.quantization.rerank(heap, move |u| {
(O::distance(vector, self.storage.vector(u)), ())
Expand Down
58 changes: 35 additions & 23 deletions crates/rabitq/src/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,36 +13,38 @@ pub trait OperatorRabitq: OperatorStorage {
fn residual(lhs: &[f32], rhs: &[f32]) -> Vec<f32>;
fn proj(projection: &[Vec<f32>], vector: &[f32]) -> Vec<f32>;

type Preprocessed0;
type Preprocessed1;
type Params;

fn preprocess(vector: &[f32]) -> (Self::Preprocessed0, Self::Preprocessed1);
type Preprocessed;

fn preprocess(vector: &[f32]) -> (Self::Params, Self::Preprocessed);
fn process(
dis_u_2: f32,
factor_ppc: f32,
factor_ip: f32,
factor_err: f32,
code: &[u8],
p0: &Self::Preprocessed0,
p1: &Self::Preprocessed1,
p0: &Self::Params,
p1: &Self::Preprocessed,
) -> Distance;
fn process_lowerbound(
dis_u_2: f32,
factor_ppc: f32,
factor_ip: f32,
factor_err: f32,
code: &[u8],
p0: &Self::Preprocessed0,
p1: &Self::Preprocessed1,
p0: &Self::Params,
p1: &Self::Preprocessed,
epsilon: f32,
) -> Distance;
fn fscan_preprocess(preprocessed: &Self::Preprocessed1) -> Vec<u8>;

fn fscan_preprocess(vector: &[f32]) -> (Self::Params, Vec<u8>);
fn fscan_process_lowerbound(
dis_u_2: f32,
factor_ppc: f32,
factor_ip: f32,
factor_err: f32,
p0: &Self::Preprocessed0,
p0: &Self::Params,
param: u16,
epsilon: f32,
) -> Distance;
Expand Down Expand Up @@ -71,8 +73,9 @@ impl OperatorRabitq for VectL2<f32> {
.collect()
}

type Preprocessed0 = (f32, f32, f32, f32);
type Preprocessed1 = ((Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>), Vec<u8>);
type Params = (f32, f32, f32, f32);

type Preprocessed = ((Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>), Vec<u8>);

fn preprocess(
vector: &[f32],
Expand Down Expand Up @@ -122,16 +125,25 @@ impl OperatorRabitq for VectL2<f32> {
Distance::from_f32(rough - epsilon * err)
}

fn fscan_preprocess(preprocessed: &((Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>), Vec<u8>)) -> Vec<u8> {
preprocessed.1.clone()
fn fscan_preprocess(vector: &[f32]) -> ((f32, f32, f32, f32), Vec<u8>) {
use quantization::quantize;
let dis_v_2 = f32::reduce_sum_of_x2(vector);
let (k, b, qvector) = quantize::quantize::<15>(vector);
let qvector_sum = if vector.len() <= 4369 {
quantize::reduce_sum_of_x_as_u16(&qvector) as f32
} else {
quantize::reduce_sum_of_x_as_u32(&qvector) as f32
};
let lut = gen(qvector);
((dis_v_2, b, k, qvector_sum), lut)
}

fn fscan_process_lowerbound(
dis_u_2: f32,
factor_ppc: f32,
factor_ip: f32,
factor_err: f32,
p0: &Self::Preprocessed0,
p0: &Self::Params,
param: u16,
epsilon: f32,
) -> Distance {
Expand Down Expand Up @@ -159,10 +171,10 @@ macro_rules! unimpl_operator_rabitq {
unimplemented!()
}

type Preprocessed0 = std::convert::Infallible;
type Preprocessed1 = std::convert::Infallible;
type Params = std::convert::Infallible;
type Preprocessed = std::convert::Infallible;

fn preprocess(_: &[f32]) -> (Self::Preprocessed0, Self::Preprocessed1) {
fn preprocess(_: &[f32]) -> (Self::Params, Self::Preprocessed) {
unimplemented!()
}

Expand All @@ -172,8 +184,8 @@ macro_rules! unimpl_operator_rabitq {
_: f32,
_: f32,
_: &[u8],
_: &Self::Preprocessed0,
_: &Self::Preprocessed1,
_: &Self::Params,
_: &Self::Preprocessed,
) -> Distance {
unimplemented!()
}
Expand All @@ -184,14 +196,14 @@ macro_rules! unimpl_operator_rabitq {
_: f32,
_: f32,
_: &[u8],
_: &Self::Preprocessed0,
_: &Self::Preprocessed1,
_: &Self::Params,
_: &Self::Preprocessed,
_: f32,
) -> Distance {
unimplemented!()
}

fn fscan_preprocess(_: &Self::Preprocessed1) -> Vec<u8> {
fn fscan_preprocess(_: &[f32]) -> (Self::Params, Vec<u8>) {
unimplemented!()
}

Expand All @@ -200,7 +212,7 @@ macro_rules! unimpl_operator_rabitq {
_: f32,
_: f32,
_: f32,
_: &Self::Preprocessed0,
_: &Self::Params,
_: u16,
_: f32,
) -> Distance {
Expand Down
48 changes: 41 additions & 7 deletions crates/rabitq/src/quant/quantization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,37 @@ impl<O: OperatorRabitq> Quantizer<O> {
pub enum QuantizationPreprocessed<O: OperatorRabitq> {
Rabitq(
(
<O as OperatorRabitq>::Preprocessed0,
<O as OperatorRabitq>::Preprocessed1,
<O as OperatorRabitq>::Params,
<O as OperatorRabitq>::Preprocessed,
),
),
}

impl<O: OperatorRabitq> From<QuantizationPreprocessed<O>> for QuantizationAnyPreprocessed<O> {
fn from(value: QuantizationPreprocessed<O>) -> Self {
match value {
QuantizationPreprocessed::Rabitq((param, blut)) => Self::Rabitq((param, Ok(blut))),
}
}
}

pub enum QuantizationFscanPreprocessed<O: OperatorRabitq> {
Rabitq((<O as OperatorRabitq>::Params, Vec<u8>)),
}

impl<O: OperatorRabitq> From<QuantizationFscanPreprocessed<O>> for QuantizationAnyPreprocessed<O> {
fn from(value: QuantizationFscanPreprocessed<O>) -> Self {
match value {
QuantizationFscanPreprocessed::Rabitq((param, lut)) => Self::Rabitq((param, Err(lut))),
}
}
}

pub enum QuantizationAnyPreprocessed<O: OperatorRabitq> {
Rabitq(
(
<O as OperatorRabitq>::Params,
Result<<O as OperatorRabitq>::Preprocessed, Vec<u8>>,
),
),
}
Expand Down Expand Up @@ -141,6 +170,12 @@ impl<O: OperatorRabitq> Quantization<O> {
}
}

pub fn fscan_preprocess(&self, lhs: &[f32]) -> QuantizationFscanPreprocessed<O> {
match &*self.train {
Quantizer::Rabitq(x) => QuantizationFscanPreprocessed::Rabitq(x.fscan_preprocess(lhs)),
}
}

pub fn process(&self, preprocessed: &QuantizationPreprocessed<O>, u: u32) -> Distance {
match (&*self.train, preprocessed) {
(Quantizer::Rabitq(x), QuantizationPreprocessed::Rabitq(lhs)) => {
Expand All @@ -159,22 +194,21 @@ impl<O: OperatorRabitq> Quantization<O> {

pub fn push_batch(
&self,
preprocessed: &QuantizationPreprocessed<O>,
preprocessed: &QuantizationAnyPreprocessed<O>,
rhs: Range<u32>,
heap: &mut Vec<(Reverse<Distance>, AlwaysEqual<u32>)>,
rq_epsilon: f32,
rq_fast_scan: bool,
) {
match (&*self.train, preprocessed) {
(Quantizer::Rabitq(x), QuantizationPreprocessed::Rabitq(lhs)) => x.push_batch(
lhs,
(Quantizer::Rabitq(x), QuantizationAnyPreprocessed::Rabitq((a, b))) => x.push_batch(
a,
b,
rhs,
heap,
&self.codes,
&self.packed_codes,
&self.meta,
rq_epsilon,
rq_fast_scan,
),
}
}
Expand Down
Loading

0 comments on commit 3b7c694

Please sign in to comment.