Skip to content

Commit

Permalink
feat: compute all distances from vectors to centroids in IVF in parallel
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <usamoi@outlook.com>
  • Loading branch information
usamoi committed Aug 28, 2024
1 parent c302125 commit 69139a4
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 14 deletions.
27 changes: 23 additions & 4 deletions crates/ivf/src/ivf_naive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use k_means::k_means;
use k_means::k_means_lookup;
use k_means::k_means_lookup_many;
use quantization::Quantization;
use rayon::iter::IntoParallelIterator;
use rayon::iter::ParallelIterator;
use std::fs::create_dir;
use std::path::Path;
use stoppable_rayon as rayon;
Expand Down Expand Up @@ -108,10 +110,27 @@ fn from_nothing<O: Op>(
rayon::check();
let centroids = k_means(nlist as usize, samples, spherical_centroids);
rayon::check();
let mut ls = vec![Vec::new(); nlist as usize];
for i in 0..collection.len() {
ls[k_means_lookup(O::interpret(collection.vector(i)), &centroids)].push(i);
}
let ls = (0..collection.len())
.into_par_iter()
.fold(
|| vec![Vec::new(); nlist as usize],
|mut state, i| {
state[k_means_lookup(O::interpret(collection.vector(i)), &centroids)].push(i);
state
},
)
.reduce(
|| vec![Vec::new(); nlist as usize],
|lhs, rhs| {
std::iter::zip(lhs, rhs)
.map(|(lhs, rhs)| {
let mut x = lhs;
x.extend(rhs);
x
})
.collect()
},
);
let mut offsets = vec![0u32; nlist as usize + 1];
for i in 0..nlist {
offsets[i as usize + 1] = offsets[i as usize] + ls[i as usize].len() as u32;
Expand Down
27 changes: 23 additions & 4 deletions crates/ivf/src/ivf_residual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use k_means::k_means;
use k_means::k_means_lookup;
use k_means::k_means_lookup_many;
use quantization::Quantization;
use rayon::iter::IntoParallelIterator;
use rayon::iter::ParallelIterator;
use std::fs::create_dir;
use std::path::Path;
use stoppable_rayon as rayon;
Expand Down Expand Up @@ -110,10 +112,27 @@ fn from_nothing<O: Op>(
rayon::check();
let centroids = k_means(nlist as usize, samples, spherical_centroids);
rayon::check();
let mut ls = vec![Vec::new(); nlist as usize];
for i in 0..collection.len() {
ls[k_means_lookup(O::interpret(collection.vector(i)), &centroids)].push(i);
}
let ls = (0..collection.len())
.into_par_iter()
.fold(
|| vec![Vec::new(); nlist as usize],
|mut state, i| {
state[k_means_lookup(O::interpret(collection.vector(i)), &centroids)].push(i);
state
},
)
.reduce(
|| vec![Vec::new(); nlist as usize],
|lhs, rhs| {
std::iter::zip(lhs, rhs)
.map(|(lhs, rhs)| {
let mut x = lhs;
x.extend(rhs);
x
})
.collect()
},
);
let mut offsets = vec![0u32; nlist as usize + 1];
for i in 0..nlist {
offsets[i as usize + 1] = offsets[i as usize] + ls[i as usize].len() as u32;
Expand Down
30 changes: 24 additions & 6 deletions crates/rabitq/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use common::mmap_array::MmapArray;
use common::remap::RemappedCollection;
use common::vec2::Vec2;
use k_means::{k_means, k_means_lookup, k_means_lookup_many};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use std::fs::create_dir;
use std::path::Path;
use stoppable_rayon as rayon;
Expand All @@ -36,7 +37,7 @@ impl<O: Op> Rabitq<O> {
pub fn create(
path: impl AsRef<Path>,
options: IndexOptions,
source: &(impl Vectors<Owned<O>> + Collection + Source),
source: &(impl Vectors<Owned<O>> + Collection + Source + Sync),
) -> Self {
let remapped = RemappedCollection::from_source(source);
from_nothing(path, options, &remapped)
Expand Down Expand Up @@ -102,7 +103,7 @@ impl<O: Op> Rabitq<O> {
fn from_nothing<O: Op>(
path: impl AsRef<Path>,
options: IndexOptions,
collection: &(impl Vectors<Owned<O>> + Collection),
collection: &(impl Vectors<Owned<O>> + Collection + Sync),
) -> Rabitq<O> {
create_dir(path.as_ref()).unwrap();
let RabitqIndexingOptions {
Expand Down Expand Up @@ -133,10 +134,27 @@ fn from_nothing<O: Op>(
rayon::check();
let centroids: Vec2<f32> = k_means(nlist as usize, samples, spherical_centroids);
rayon::check();
let mut ls = vec![Vec::new(); nlist as usize];
for i in 0..collection.len() {
ls[k_means_lookup(O::cast(collection.vector(i)), &centroids)].push(i);
}
let ls = (0..collection.len())
.into_par_iter()
.fold(
|| vec![Vec::new(); nlist as usize],
|mut state, i| {
state[k_means_lookup(O::cast(collection.vector(i)), &centroids)].push(i);
state
},
)
.reduce(
|| vec![Vec::new(); nlist as usize],
|lhs, rhs| {
std::iter::zip(lhs, rhs)
.map(|(lhs, rhs)| {
let mut x = lhs;
x.extend(rhs);
x
})
.collect()
},
);
let mut offsets = vec![0u32; nlist as usize + 1];
for i in 0..nlist {
offsets[i as usize + 1] = offsets[i as usize] + ls[i as usize].len() as u32;
Expand Down

0 comments on commit 69139a4

Please sign in to comment.