Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[faiss v1.7.1] Added high-level IVFFlatIndex impl #17

Merged
merged 7 commits into from
Jun 3, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions faiss-sys/src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -694,12 +694,7 @@ extern "C" {
index: *mut FaissIndexFlat1D,
) -> ::std::os::raw::c_int;
}
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct FaissIndexIVFFlat_H {
_unused: [u8; 0],
}
pub type FaissIndexIVFFlat = FaissIndexIVFFlat_H;
pub type FaissIndexIVFFlat = FaissIndex_H;
extern "C" {
pub fn faiss_IndexIVFFlat_free(obj: *mut FaissIndexIVFFlat);
}
Expand Down
234 changes: 234 additions & 0 deletions src/index/ivf_flat.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
//! Interface and implementation to IVFFlat index type.

use super::*;

use crate::error::Result;
use crate::faiss_try;
use std::mem;
use std::ptr;

/// Alias for the native implementation of a flat index.
pub type IVFFlatIndex = IVFFlatIndexImpl;

/// Native implementation of a flat index.
#[derive(Debug)]
pub struct IVFFlatIndexImpl {
inner: *mut FaissIndexIVFFlat,
}

unsafe impl Send for IVFFlatIndexImpl {}
unsafe impl Sync for IVFFlatIndexImpl {}

impl CpuIndex for IVFFlatIndexImpl {}

impl Drop for IVFFlatIndexImpl {
fn drop(&mut self) {
unsafe {
faiss_IndexIVFFlat_free(self.inner);
}
}
}

impl IVFFlatIndexImpl {
/// Create a new IVF flat index.
pub fn new(quantizer: flat::FlatIndex, d: u32, nlist: u32, metric: MetricType) -> Result<Self> {
unsafe {
let metric = metric as c_uint;
let mut inner = ptr::null_mut();
faiss_try(faiss_IndexIVFFlat_new_with_metric(
&mut inner,
quantizer.inner_ptr(),
d as usize,
nlist as usize,
metric,
))?;

mem::forget(quantizer); // own_fields default == true
Enet4 marked this conversation as resolved.
Show resolved Hide resolved
Ok(IVFFlatIndexImpl { inner })
}
}

/// Create a new IVF flat index with L2 as the metric type.
pub fn new_l2(quantizer: flat::FlatIndex, d: u32, nlist: u32) -> Result<Self> {
IVFFlatIndexImpl::new(quantizer, d, nlist, MetricType::L2)
}

/// Create a new IVF flat index with IP (inner product) as the metric type.
pub fn new_ip(quantizer: flat::FlatIndex, d: u32, nlist: u32) -> Result<Self> {
IVFFlatIndexImpl::new(quantizer, d, nlist, MetricType::InnerProduct)
}
}

impl NativeIndex for IVFFlatIndexImpl {
fn inner_ptr(&self) -> *mut FaissIndex {
self.inner
}
}

impl FromInnerPtr for IVFFlatIndexImpl {
unsafe fn from_inner_ptr(inner_ptr: *mut FaissIndex) -> Self {
IVFFlatIndexImpl {
inner: inner_ptr as *mut FaissIndexIVFFlat,
}
}
}

impl_native_index!(IVFFlatIndex);

impl_native_index_clone!(IVFFlatIndex);

impl ConcurrentIndex for IVFFlatIndexImpl {
fn assign(&self, query: &[f32], k: usize) -> Result<AssignSearchResult> {
unsafe {
let nq = query.len() / self.d() as usize;
let mut out_labels = vec![Idx::none(); k * nq];
faiss_try(faiss_Index_assign(
self.inner,
nq as idx_t,
query.as_ptr(),
out_labels.as_mut_ptr() as *mut _,
k as i64,
))?;
Ok(AssignSearchResult { labels: out_labels })
}
}
fn search(&self, query: &[f32], k: usize) -> Result<SearchResult> {
unsafe {
let nq = query.len() / self.d() as usize;
let mut distances = vec![0_f32; k * nq];
let mut labels = vec![Idx::none(); k * nq];
faiss_try(faiss_Index_search(
self.inner,
nq as idx_t,
query.as_ptr(),
k as idx_t,
distances.as_mut_ptr(),
labels.as_mut_ptr() as *mut _,
))?;
Ok(SearchResult { distances, labels })
}
}
fn range_search(&self, query: &[f32], radius: f32) -> Result<RangeSearchResult> {
unsafe {
let nq = (query.len() / self.d() as usize) as idx_t;
let mut p_res: *mut FaissRangeSearchResult = ptr::null_mut();
faiss_try(faiss_RangeSearchResult_new(&mut p_res, nq))?;
faiss_try(faiss_Index_range_search(
self.inner,
nq,
query.as_ptr(),
radius,
p_res,
))?;
Ok(RangeSearchResult { inner: p_res })
}
}
}

impl IndexImpl {
/// Attempt a dynamic cast of an index to the IVF flat index type.
pub fn into_ivf_flat(self) -> Result<IVFFlatIndexImpl> {
unsafe {
let new_inner = faiss_IndexIVFFlat_cast(self.inner_ptr());
if new_inner.is_null() {
Err(Error::BadCast)
} else {
mem::forget(self);
Ok(IVFFlatIndexImpl { inner: new_inner })
}
}
}
}

#[cfg(test)]
mod tests {

use super::IVFFlatIndexImpl;
use crate::index::flat::FlatIndexImpl;
use crate::index::{index_factory, ConcurrentIndex, Idx, Index};
use crate::MetricType;

const D: u32 = 8;

#[test]
// #[ignore]
fn index_search() {
let q = FlatIndexImpl::new_l2(D).unwrap();
let mut index = IVFFlatIndexImpl::new_l2(q, D, 1).unwrap();
assert_eq!(index.d(), D);
assert_eq!(index.ntotal(), 0);
let some_data = &[
7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 4.,
-4., -8., 1., 1., 2., 4., -1., 8., 8., 10., -10., -10., 10., -10., 10., 16., 16., 32.,
25., 20., 20., 40., 15.,
];
index.train(some_data).unwrap();
index.add(some_data).unwrap();
assert_eq!(index.ntotal(), 5);

let my_query = [0.; D as usize];
let result = index.search(&my_query, 3).unwrap();
assert_eq!(result.labels.len(), 3);
assert!(result.labels.into_iter().all(Idx::is_some));
assert_eq!(result.distances.len(), 3);
assert!(result.distances.iter().all(|x| *x > 0.));

let my_query = [100.; D as usize];
// flat index can be used behind an immutable ref
let result = (&index).search(&my_query, 3).unwrap();
assert_eq!(result.labels.len(), 3);
assert!(result.labels.into_iter().all(Idx::is_some));
assert_eq!(result.distances.len(), 3);
assert!(result.distances.iter().all(|x| *x > 0.));

index.reset().unwrap();
assert_eq!(index.ntotal(), 0);
}

#[test]
fn index_assign() {
let q = FlatIndexImpl::new_l2(D).unwrap();
let mut index = IVFFlatIndexImpl::new_l2(q, D, 1).unwrap();
assert_eq!(index.d(), D);
assert_eq!(index.ntotal(), 0);
let some_data = &[
7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 4.,
-4., -8., 1., 1., 2., 4., -1., 8., 8., 10., -10., -10., 10., -10., 10., 16., 16., 32.,
25., 20., 20., 40., 15.,
];
index.train(some_data).unwrap();
index.add(some_data).unwrap();
assert_eq!(index.ntotal(), 5);

let my_query = [0.; D as usize];
let result = index.assign(&my_query, 3).unwrap();
assert_eq!(result.labels.len(), 3);
assert!(result.labels.into_iter().all(Idx::is_some));

let my_query = [100.; D as usize];
// flat index can be used behind an immutable ref
let result = (&index).assign(&my_query, 3).unwrap();
assert_eq!(result.labels.len(), 3);
assert!(result.labels.into_iter().all(Idx::is_some));

index.reset().unwrap();
assert_eq!(index.ntotal(), 0);
}

#[test]
fn ivf_flat_index_from_cast() {
let mut index = index_factory(8, "IVF1,Flat", MetricType::L2).unwrap();
let some_data = &[
7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 0., 0.,
0., 1., 1., 0., 0., -1., 100., 100., 100., 100., -100., 100., 100., 100., 120., 100.,
100., 105., -100., 100., 100., 105.,
];
index.train(some_data).unwrap();
index.add(some_data).unwrap();
assert_eq!(index.ntotal(), 5);

let index: IVFFlatIndexImpl = index.into_ivf_flat().unwrap();
assert_eq!(index.is_trained(), true);
assert_eq!(index.ntotal(), 5);
}
}
1 change: 1 addition & 0 deletions src/index/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use faiss_sys::*;
pub mod flat;
pub mod id_map;
pub mod io;
pub mod ivf_flat;
pub mod lsh;

#[cfg(feature = "gpu")]
Expand Down