Skip to content

Commit

Permalink
refactor: make scalar quantization faster
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <usamoi@outlook.com>
  • Loading branch information
usamoi committed Sep 23, 2024
1 parent a79b9f3 commit 04be9b2
Showing 1 changed file with 121 additions and 59 deletions.
180 changes: 121 additions & 59 deletions crates/quantization/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,31 @@ use base::search::RerankerPop;
use base::search::RerankerPush;
use base::search::Vectors;
use base::vector::*;
use common::vec2::Vec2;
use rayon::iter::IntoParallelIterator;
use rayon::iter::ParallelIterator;
use serde::Deserialize;
use serde::Serialize;
use std::cmp::Reverse;
use std::marker::PhantomData;
use std::ops::Range;
use stoppable_rayon as rayon;

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct ScalarQuantizer<O: OperatorScalarQuantization> {
dims: u32,
bits: u32,
max: Vec<f32>,
min: Vec<f32>,
centroids: Vec2<f32>,
max: Vec<f32>,
_phantom: PhantomData<fn(O) -> O>,
}

impl<O: OperatorScalarQuantization> Quantizer<O> for ScalarQuantizer<O> {
fn train(
vector_options: VectorOptions,
options: Option<QuantizationOptions>,
vectors: &impl Vectors<O::Vector>,
transform: impl Fn(Borrowed<'_, O>) -> O::Vector + Copy,
vectors: &(impl Vectors<O::Vector> + Sync),
transform: impl Fn(Borrowed<'_, O>) -> O::Vector + Copy + Sync,
) -> Self {
let options = if let Some(QuantizationOptions::Scalar(x)) = options {
x
Expand All @@ -50,32 +51,46 @@ impl<O: OperatorScalarQuantization> Quantizer<O> for ScalarQuantizer<O> {
};
let dims = vector_options.dims;
let bits = options.bits;
let mut max = vec![f32::NEG_INFINITY; dims as usize];
let mut min = vec![f32::INFINITY; dims as usize];
let n = vectors.len();
for i in 0..n {
let vector = transform(vectors.vector(i));
let vector = vector.as_borrowed();
for j in 0..dims {
min[j as usize] = min[j as usize].min(O::get(vector, j).to_f32());
max[j as usize] = max[j as usize].max(O::get(vector, j).to_f32());
}
}
let mut centroids = Vec2::zeros((1 << bits, dims as usize));
for i in 0..dims {
let bas = min[i as usize];
let del = max[i as usize] - min[i as usize];
for j in 0_usize..(1 << bits) {
let val = j as f32 / ((1 << bits) - 1) as f32;
centroids[(j, i as usize)] = bas + val * del;
}
}
let (min, max) = (0..n)
.into_par_iter()
.fold(
|| {
(
vec![f32::INFINITY; dims as usize],
vec![f32::NEG_INFINITY; dims as usize],
)
},
|(mut min, mut max), i| {
let vector = transform(vectors.vector(i));
let vector = vector.as_borrowed();
for j in 0..dims {
min[j as usize] = min[j as usize].min(O::get(vector, j).to_f32());
max[j as usize] = max[j as usize].max(O::get(vector, j).to_f32());
}
(min, max)
},
)
.reduce(
|| {
(
vec![f32::INFINITY; dims as usize],
vec![f32::NEG_INFINITY; dims as usize],
)
},
|(mut min, mut max), (rmin, rmax)| {
for j in 0..dims {
min[j as usize] = min[j as usize].min(rmin[j as usize]);
max[j as usize] = max[j as usize].max(rmax[j as usize]);
}
(min, max)
},
);
Self {
dims,
bits,
max,
min,
centroids,
max,
_phantom: PhantomData,
}
}
Expand Down Expand Up @@ -150,7 +165,7 @@ impl<O: OperatorScalarQuantization> Quantizer<O> for ScalarQuantizer<O> {
type Lut = Vec<f32>;

fn preprocess(&self, vector: Borrowed<'_, O>) -> Self::Lut {
O::preprocess(self.dims, self.bits, &self.max, &self.min, vector)
O::preprocess(self.dims, self.bits, &self.min, &self.max, vector)
}

fn process(&self, lut: &Self::Lut, code: &[u8], _: Borrowed<'_, O>) -> Distance {
Expand All @@ -165,7 +180,7 @@ impl<O: OperatorScalarQuantization> Quantizer<O> for ScalarQuantizer<O> {
);

fn fscan_preprocess(&self, vector: Borrowed<'_, O>) -> Self::FLut {
O::fscan_preprocess(self.dims, self.bits, &self.max, &self.min, vector)
O::fscan_preprocess(self.dims, self.bits, &self.min, &self.max, vector)
}

fn fscan_process(&self, flut: &Self::FLut, code: &[u8]) -> [Distance; 32] {
Expand Down Expand Up @@ -291,17 +306,17 @@ pub trait OperatorScalarQuantization: Operator {
fn preprocess(
dims: u32,
bits: u32,
max: &[f32],
min: &[f32],
max: &[f32],
vector: Borrowed<'_, Self>,
) -> Vec<f32>;
fn process(dims: u32, bits: u32, lut: &[f32], code: &[u8]) -> Distance;

fn fscan_preprocess(
dims: u32,
bits: u32,
max: &[f32],
min: &[f32],
max: &[f32],
vector: Borrowed<'_, Self>,
) -> (u32, f32, f32, Vec<u8>);
fn fscan_process(flut: &(u32, f32, f32, Vec<u8>), code: &[u8]) -> [Distance; 32];
Expand All @@ -316,22 +331,46 @@ impl<S: ScalarLike> OperatorScalarQuantization for VectDot<S> {
fn preprocess(
dims: u32,
bits: u32,
max: &[f32],
min: &[f32],
max: &[f32],
vector: Borrowed<'_, Self>,
) -> Vec<f32> {
let mut xy = Vec::with_capacity(dims as _);
for i in 0..dims {
let bas = min[i as usize];
let del = max[i as usize] - min[i as usize];
xy.extend((0..1 << bits).map(|k| {
let x = vector.slice()[i as usize].to_f32();
let val = k as f32 / ((1 << bits) - 1) as f32;
let y = bas + val * del;
x * y
}));
#[inline(never)]
fn internal<const BITS: usize, S: ScalarLike>(
dims: usize,
min: &[f32],
max: &[f32],
vector: &[S],
) -> Vec<f32> {
assert!(dims <= 65535);
assert!(dims == min.len());
assert!(dims == max.len());
assert!(dims == vector.len());
let mut table = Vec::<f32>::with_capacity(dims * (1 << BITS));
for i in 0..dims {
let bas = min[i];
let del = (max[i] - min[i]) / ((1 << BITS) - 1) as f32;
for j in 0..1 << BITS {
let x = vector[i].to_f32();
let y = bas + (j as f32) * del;
let value = x * y;
unsafe {
table.as_mut_ptr().add(i * (1 << BITS) + j).write(value);
}
}
}
unsafe {
table.set_len(dims * (1 << BITS));
}
table
}
match bits {
1 => internal::<1, _>(dims as _, min, max, vector.slice()),
2 => internal::<2, _>(dims as _, min, max, vector.slice()),
4 => internal::<4, _>(dims as _, min, max, vector.slice()),
8 => internal::<8, _>(dims as _, min, max, vector.slice()),
_ => unreachable!(),
}
xy
}
fn process(dims: u32, bits: u32, lut: &[f32], rhs: &[u8]) -> Distance {
fn internal<const BITS: u32>(dims: u32, t: &[f32], f: impl Fn(usize) -> usize) -> Distance {
Expand Down Expand Up @@ -359,11 +398,11 @@ impl<S: ScalarLike> OperatorScalarQuantization for VectDot<S> {
fn fscan_preprocess(
dims: u32,
bits: u32,
max: &[f32],
min: &[f32],
max: &[f32],
vector: Borrowed<'_, Self>,
) -> (u32, f32, f32, Vec<u8>) {
let (k, b, t) = quantize::<255>(&Self::preprocess(dims, bits, max, min, vector));
let (k, b, t) = quantize::<255>(&Self::preprocess(dims, bits, min, max, vector));
(dims, k, b, t)
}
fn fscan_process(flut: &(u32, f32, f32, Vec<u8>), codes: &[u8]) -> [Distance; 32] {
Expand All @@ -382,23 +421,46 @@ impl<S: ScalarLike> OperatorScalarQuantization for VectL2<S> {
fn preprocess(
dims: u32,
bits: u32,
max: &[f32],
min: &[f32],
max: &[f32],
vector: Borrowed<'_, Self>,
) -> Vec<f32> {
let mut d2 = Vec::with_capacity(dims as _);
for i in 0..dims {
let bas = min[i as usize];
let del = max[i as usize] - min[i as usize];
d2.extend((0..1 << bits).map(|k| {
let x = vector.slice()[i as usize].to_f32();
let val = k as f32 / ((1 << bits) - 1) as f32;
let y = bas + val * del;
let d = x - y;
d * d
}));
#[inline(never)]
fn internal<const BITS: usize, S: ScalarLike>(
dims: usize,
min: &[f32],
max: &[f32],
vector: &[S],
) -> Vec<f32> {
assert!(dims <= 65535);
assert!(dims == min.len());
assert!(dims == max.len());
assert!(dims == vector.len());
let mut table = Vec::<f32>::with_capacity(dims * (1 << BITS));
for i in 0..dims {
let bas = min[i];
let del = (max[i] - min[i]) / ((1 << BITS) - 1) as f32;
for j in 0..1 << BITS {
let x = vector[i].to_f32();
let y = bas + (j as f32) * del;
let value = (x - y) * (x - y);
unsafe {
table.as_mut_ptr().add(i * (1 << BITS) + j).write(value);
}
}
}
unsafe {
table.set_len(dims * (1 << BITS));
}
table
}
match bits {
1 => internal::<1, _>(dims as _, min, max, vector.slice()),
2 => internal::<2, _>(dims as _, min, max, vector.slice()),
4 => internal::<4, _>(dims as _, min, max, vector.slice()),
8 => internal::<8, _>(dims as _, min, max, vector.slice()),
_ => unreachable!(),
}
d2
}
fn process(dims: u32, bits: u32, lut: &[f32], rhs: &[u8]) -> Distance {
fn internal<const BITS: u32>(dims: u32, t: &[f32], f: impl Fn(usize) -> usize) -> Distance {
Expand Down Expand Up @@ -426,11 +488,11 @@ impl<S: ScalarLike> OperatorScalarQuantization for VectL2<S> {
fn fscan_preprocess(
dims: u32,
bits: u32,
max: &[f32],
min: &[f32],
max: &[f32],
vector: Borrowed<'_, Self>,
) -> (u32, f32, f32, Vec<u8>) {
let (k, b, t) = quantize::<255>(&Self::preprocess(dims, bits, max, min, vector));
let (k, b, t) = quantize::<255>(&Self::preprocess(dims, bits, min, max, vector));
(dims, k, b, t)
}
fn fscan_process(flut: &(u32, f32, f32, Vec<u8>), codes: &[u8]) -> [Distance; 32] {
Expand Down

0 comments on commit 04be9b2

Please sign in to comment.