diff --git a/src/index/id_map.rs b/src/index/id_map.rs index 8831efe..a55fe76 100644 --- a/src/index/id_map.rs +++ b/src/index/id_map.rs @@ -67,6 +67,8 @@ use std::mem; use std::os::raw::c_int; use std::ptr; +use super::IndexImpl; + /// Wrapper for implementing arbitrary ID mapping to an index. /// /// See the [module level documentation] for more information. @@ -364,6 +366,24 @@ where } } +impl IndexImpl { + /// Attempt a dynamic cast of the index to one that is [ID-mapped][1]. + /// + /// [1]: crate::IdMap + pub fn into_id_map(self) -> Result> { + unsafe { + let new_inner = faiss_IndexIDMap_cast(self.inner_ptr()); + if new_inner.is_null() { + Err(Error::BadCast) + } else { + mem::forget(self); + let index_inner = faiss_IndexIDMap_sub_index(new_inner); + Ok(IdMap { inner: new_inner, index_inner, phantom: PhantomData }) + } + } + } +} + #[cfg(test)] mod tests { use super::IdMap; @@ -471,4 +491,12 @@ mod tests { let flat_index: FlatIndexImpl = id_index.try_into_inner().unwrap(); assert_eq!(flat_index.d(), 4); } + + #[test] + fn index_impl_to_id_map() { + let index = index_factory(4, "IDMap,Flat", MetricType::L2).unwrap(); + let id_map = index.into_id_map().unwrap(); + + assert_eq!(id_map.d(), 4); + } }