diff --git a/c_api/IndexIVF_c_ex.cpp b/c_api/IndexIVF_c_ex.cpp index 1fbec3aec0..5b72e15c94 100644 --- a/c_api/IndexIVF_c_ex.cpp +++ b/c_api/IndexIVF_c_ex.cpp @@ -12,6 +12,7 @@ #include "macros_impl.h" using faiss::IndexIVF; +using faiss::SearchParametersIVF; int faiss_IndexIVF_set_direct_map(FaissIndexIVF* index, int direct_map_type) { try { @@ -20,3 +21,14 @@ int faiss_IndexIVF_set_direct_map(FaissIndexIVF* index, int direct_map_type) { } CATCH_AND_HANDLE } + +int faiss_SearchParametersIVF_new_with_sel( + FaissSearchParametersIVF** p_sp, + FaissIDSelector* sel) { + try { + SearchParametersIVF* sp = new SearchParametersIVF; + sp->sel = reinterpret_cast(sel); + *p_sp = reinterpret_cast(sp); + } + CATCH_AND_HANDLE +} \ No newline at end of file diff --git a/c_api/IndexIVF_c_ex.h b/c_api/IndexIVF_c_ex.h index 49e7bbfd0c..e82c5c106f 100644 --- a/c_api/IndexIVF_c_ex.h +++ b/c_api/IndexIVF_c_ex.h @@ -23,6 +23,10 @@ int faiss_IndexIVF_set_direct_map( FaissIndexIVF* index, int direct_map_type); +int faiss_SearchParametersIVF_new_with_sel( + FaissSearchParametersIVF** p_sp, + FaissIDSelector* sel); + #ifdef __cplusplus } #endif diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp index 1417b91fdd..ecac7ed454 100644 --- a/faiss/IndexIDMap.cpp +++ b/faiss/IndexIDMap.cpp @@ -16,7 +16,6 @@ #include #include -#include #include #include @@ -71,6 +70,27 @@ void IndexIDMapTemplate::add_with_ids( this->ntotal = index->ntotal; } +namespace { + +/// RAII object to reset the IDSelector in the params object +struct ScopedSelChange { + SearchParameters* params = nullptr; + IDSelector* old_sel = nullptr; + + void set(SearchParameters* params, IDSelector* new_sel) { + this->params = params; + old_sel = params->sel; + params->sel = new_sel; + } + ~ScopedSelChange() { + if (params) { + params->sel = old_sel; + } + } +}; + +} // namespace + template void IndexIDMapTemplate::search( idx_t n, @@ -79,9 +99,26 @@ void IndexIDMapTemplate::search( typename IndexT::distance_t* distances, idx_t* labels, const SearchParameters* params) const { - FAISS_THROW_IF_NOT_MSG( - !params, "search params not supported for this index"); - index->search(n, x, k, distances, labels); + IDSelectorTranslated this_idtrans(this->id_map, nullptr); + ScopedSelChange sel_change; + + if (params && params->sel) { + auto idtrans = dynamic_cast(params->sel); + + if (!idtrans) { + /* + FAISS_THROW_IF_NOT_MSG( + idtrans, + "IndexIDMap requires an IDSelectorTranslated on input"); + */ + // then make an idtrans and force it into the SearchParameters + // (hence the const_cast) + auto params_non_const = const_cast(params); + this_idtrans.sel = params->sel; + sel_change.set(params_non_const, &this_idtrans); + } + } + index->search(n, x, k, distances, labels, params); idx_t* li = labels; #pragma omp parallel for for (idx_t i = 0; i < n * k; i++) { @@ -106,26 +143,10 @@ void IndexIDMapTemplate::range_search( } } -namespace { - -struct IDTranslatedSelector : IDSelector { - const std::vector& id_map; - const IDSelector& sel; - IDTranslatedSelector( - const std::vector& id_map, - const IDSelector& sel) - : id_map(id_map), sel(sel) {} - bool is_member(idx_t id) const override { - return sel.is_member(id_map[id]); - } -}; - -} // namespace - template size_t IndexIDMapTemplate::remove_ids(const IDSelector& sel) { // remove in sub-index first - IDTranslatedSelector sel2(id_map, sel); + IDSelectorTranslated sel2(id_map, &sel); size_t nremove = index->remove_ids(sel2); int64_t j = 0; diff --git a/faiss/IndexIDMap.h b/faiss/IndexIDMap.h index 0643366c98..2297e5e51c 100644 --- a/faiss/IndexIDMap.h +++ b/faiss/IndexIDMap.h @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -102,4 +103,25 @@ struct IndexIDMap2Template : IndexIDMapTemplate { using IndexIDMap2 = IndexIDMap2Template; using IndexBinaryIDMap2 = IndexIDMap2Template; +// IDSelector that translates the ids using an IDMap +struct IDSelectorTranslated : IDSelector { + const std::vector& id_map; + const IDSelector* sel; + + IDSelectorTranslated( + const std::vector& id_map, + const IDSelector* sel) + : id_map(id_map), sel(sel) {} + + IDSelectorTranslated(IndexBinaryIDMap& index_idmap, const IDSelector* sel) + : id_map(index_idmap.id_map), sel(sel) {} + + IDSelectorTranslated(IndexIDMap& index_idmap, const IDSelector* sel) + : id_map(index_idmap.id_map), sel(sel) {} + + bool is_member(idx_t id) const override { + return sel->is_member(id_map[id]); + } +}; + } // namespace faiss diff --git a/faiss/python/__init__.py b/faiss/python/__init__.py index 92eeb3479b..49343043bd 100644 --- a/faiss/python/__init__.py +++ b/faiss/python/__init__.py @@ -192,6 +192,7 @@ def replacement_function(*args): add_ref_in_constructor(IDSelectorAnd, slice(2)) add_ref_in_constructor(IDSelectorOr, slice(2)) add_ref_in_constructor(IDSelectorXOr, slice(2)) +add_ref_in_constructor(IDSelectorTranslated, slice(2)) # seems really marginal... # remove_ref_from_method(IndexReplicas, 'removeIndex', 0) diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig index 04906f69fe..9133683330 100644 --- a/faiss/python/swigfaiss.swig +++ b/faiss/python/swigfaiss.swig @@ -494,11 +494,6 @@ void gpu_sync_all_devices() %template(IndexBinaryReplicas) faiss::IndexReplicasTemplate; %include -%include -%template(IndexIDMap) faiss::IndexIDMapTemplate; -%template(IndexBinaryIDMap) faiss::IndexIDMapTemplate; -%template(IndexIDMap2) faiss::IndexIDMap2Template; -%template(IndexBinaryIDMap2) faiss::IndexIDMap2Template; %include @@ -513,6 +508,13 @@ void gpu_sync_all_devices() %include %include +%include +%template(IndexIDMap) faiss::IndexIDMapTemplate; +%template(IndexBinaryIDMap) faiss::IndexIDMapTemplate; +%template(IndexIDMap2) faiss::IndexIDMap2Template; +%template(IndexBinaryIDMap2) faiss::IndexIDMap2Template; + + %include #ifdef GPU_WRAPPER diff --git a/tests/test_search_params.py b/tests/test_search_params.py index bd7e813bf9..2c76f151dc 100644 --- a/tests/test_search_params.py +++ b/tests/test_search_params.py @@ -101,17 +101,17 @@ def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METR sel = faiss.IDSelectorNot(faiss.IDSelectorBatch(inverse_subset)) elif id_selector_type == "or": sel = faiss.IDSelectorOr( - faiss.IDSelectorBatch(lhs_subset), + faiss.IDSelectorBatch(lhs_subset), faiss.IDSelectorBatch(rhs_subset) ) elif id_selector_type == "and": sel = faiss.IDSelectorAnd( - faiss.IDSelectorBatch(lhs_subset), + faiss.IDSelectorBatch(lhs_subset), faiss.IDSelectorBatch(rhs_subset) ) elif id_selector_type == "xor": sel = faiss.IDSelectorXOr( - faiss.IDSelectorBatch(lhs_subset), + faiss.IDSelectorBatch(lhs_subset), faiss.IDSelectorBatch(rhs_subset) ) else: @@ -181,7 +181,7 @@ def test_Flat_id_bitmap(self): def test_Flat_id_not(self): self.do_test_id_selector("Flat", id_selector_type="not") - + def test_Flat_id_or(self): self.do_test_id_selector("Flat", id_selector_type="or") @@ -220,6 +220,41 @@ def do_test_id_selector_weak(self, index_key): def test_HSNW(self): self.do_test_id_selector_weak("HNSW") + def test_idmap(self): + ds = datasets.SyntheticDataset(32, 100, 100, 20) + rs = np.random.RandomState(123) + ids = rs.choice(10000, size=100, replace=False) + mask = ids % 2 == 0 + index = faiss.index_factory(ds.d, "IDMap,SQ8") + index.train(ds.get_train()) + + # ref result + index.add_with_ids(ds.get_database()[mask], ids[mask]) + Dref, Iref = index.search(ds.get_queries(), 10) + + # with selector + index.reset() + index.add_with_ids(ds.get_database(), ids) + + valid_ids = ids[mask] + sel = faiss.IDSelectorTranslated( + index, faiss.IDSelectorBatch(valid_ids)) + + Dnew, Inew = index.search( + ds.get_queries(), 10, + params=faiss.SearchParameters(sel=sel) + ) + np.testing.assert_array_equal(Iref, Inew) + np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5) + + # let the IDMap::search add the translation... + Dnew, Inew = index.search( + ds.get_queries(), 10, + params=faiss.SearchParameters(sel=faiss.IDSelectorBatch(valid_ids)) + ) + np.testing.assert_array_equal(Iref, Inew) + np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5) + class TestSearchParams(unittest.TestCase):