From 69139a4f422611bc23d49df045596018bb64b4f0 Mon Sep 17 00:00:00 2001 From: usamoi Date: Wed, 28 Aug 2024 14:49:31 +0800 Subject: [PATCH] feat: compute all distances from vectors to centroids in IVF parallelly Signed-off-by: usamoi --- crates/ivf/src/ivf_naive.rs | 27 +++++++++++++++++++++++---- crates/ivf/src/ivf_residual.rs | 27 +++++++++++++++++++++++---- crates/rabitq/src/lib.rs | 30 ++++++++++++++++++++++++------ 3 files changed, 70 insertions(+), 14 deletions(-) diff --git a/crates/ivf/src/ivf_naive.rs b/crates/ivf/src/ivf_naive.rs index d55877843..bbada9ffb 100644 --- a/crates/ivf/src/ivf_naive.rs +++ b/crates/ivf/src/ivf_naive.rs @@ -12,6 +12,8 @@ use k_means::k_means; use k_means::k_means_lookup; use k_means::k_means_lookup_many; use quantization::Quantization; +use rayon::iter::IntoParallelIterator; +use rayon::iter::ParallelIterator; use std::fs::create_dir; use std::path::Path; use stoppable_rayon as rayon; @@ -108,10 +110,27 @@ fn from_nothing( rayon::check(); let centroids = k_means(nlist as usize, samples, spherical_centroids); rayon::check(); - let mut ls = vec![Vec::new(); nlist as usize]; - for i in 0..collection.len() { - ls[k_means_lookup(O::interpret(collection.vector(i)), ¢roids)].push(i); - } + let ls = (0..collection.len()) + .into_par_iter() + .fold( + || vec![Vec::new(); nlist as usize], + |mut state, i| { + state[k_means_lookup(O::interpret(collection.vector(i)), ¢roids)].push(i); + state + }, + ) + .reduce( + || vec![Vec::new(); nlist as usize], + |lhs, rhs| { + std::iter::zip(lhs, rhs) + .map(|(lhs, rhs)| { + let mut x = lhs; + x.extend(rhs); + x + }) + .collect() + }, + ); let mut offsets = vec![0u32; nlist as usize + 1]; for i in 0..nlist { offsets[i as usize + 1] = offsets[i as usize] + ls[i as usize].len() as u32; diff --git a/crates/ivf/src/ivf_residual.rs b/crates/ivf/src/ivf_residual.rs index 888538d0f..8ae3cf36d 100644 --- a/crates/ivf/src/ivf_residual.rs +++ b/crates/ivf/src/ivf_residual.rs @@ -12,6 +12,8 @@ use k_means::k_means; use k_means::k_means_lookup; use k_means::k_means_lookup_many; use quantization::Quantization; +use rayon::iter::IntoParallelIterator; +use rayon::iter::ParallelIterator; use std::fs::create_dir; use std::path::Path; use stoppable_rayon as rayon; @@ -110,10 +112,27 @@ fn from_nothing( rayon::check(); let centroids = k_means(nlist as usize, samples, spherical_centroids); rayon::check(); - let mut ls = vec![Vec::new(); nlist as usize]; - for i in 0..collection.len() { - ls[k_means_lookup(O::interpret(collection.vector(i)), ¢roids)].push(i); - } + let ls = (0..collection.len()) + .into_par_iter() + .fold( + || vec![Vec::new(); nlist as usize], + |mut state, i| { + state[k_means_lookup(O::interpret(collection.vector(i)), ¢roids)].push(i); + state + }, + ) + .reduce( + || vec![Vec::new(); nlist as usize], + |lhs, rhs| { + std::iter::zip(lhs, rhs) + .map(|(lhs, rhs)| { + let mut x = lhs; + x.extend(rhs); + x + }) + .collect() + }, + ); let mut offsets = vec![0u32; nlist as usize + 1]; for i in 0..nlist { offsets[i as usize + 1] = offsets[i as usize] + ls[i as usize].len() as u32; diff --git a/crates/rabitq/src/lib.rs b/crates/rabitq/src/lib.rs index d72c764f9..54271ffd6 100644 --- a/crates/rabitq/src/lib.rs +++ b/crates/rabitq/src/lib.rs @@ -18,6 +18,7 @@ use common::mmap_array::MmapArray; use common::remap::RemappedCollection; use common::vec2::Vec2; use k_means::{k_means, k_means_lookup, k_means_lookup_many}; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; use std::fs::create_dir; use std::path::Path; use stoppable_rayon as rayon; @@ -36,7 +37,7 @@ impl Rabitq { pub fn create( path: impl AsRef, options: IndexOptions, - source: &(impl Vectors> + Collection + Source), + source: &(impl Vectors> + Collection + Source + Sync), ) -> Self { let remapped = RemappedCollection::from_source(source); from_nothing(path, options, &remapped) @@ -102,7 +103,7 @@ impl Rabitq { fn from_nothing( path: impl AsRef, options: IndexOptions, - collection: &(impl Vectors> + Collection), + collection: &(impl Vectors> + Collection + Sync), ) -> Rabitq { create_dir(path.as_ref()).unwrap(); let RabitqIndexingOptions { @@ -133,10 +134,27 @@ fn from_nothing( rayon::check(); let centroids: Vec2 = k_means(nlist as usize, samples, spherical_centroids); rayon::check(); - let mut ls = vec![Vec::new(); nlist as usize]; - for i in 0..collection.len() { - ls[k_means_lookup(O::cast(collection.vector(i)), ¢roids)].push(i); - } + let ls = (0..collection.len()) + .into_par_iter() + .fold( + || vec![Vec::new(); nlist as usize], + |mut state, i| { + state[k_means_lookup(O::cast(collection.vector(i)), ¢roids)].push(i); + state + }, + ) + .reduce( + || vec![Vec::new(); nlist as usize], + |lhs, rhs| { + std::iter::zip(lhs, rhs) + .map(|(lhs, rhs)| { + let mut x = lhs; + x.extend(rhs); + x + }) + .collect() + }, + ); let mut offsets = vec![0u32; nlist as usize + 1]; for i in 0..nlist { offsets[i as usize + 1] = offsets[i as usize] + ls[i as usize].len() as u32;