From a69568eed385bda4c90db4da24917b9d5d64ad06 Mon Sep 17 00:00:00 2001 From: Matthijs Douze Date: Fri, 16 Jun 2023 02:12:32 -0700 Subject: [PATCH] Binary cloning and GPU range search (#2916) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2916 Overall better support for binary indexes: - cloning (to CPU and GPU), only for BinaryFlat for now - fix bug in reconstruct_n - range_search_max_results Reviewed By: algoriddle Differential Revision: D46755778 fbshipit-source-id: a8890e63ed284eb6478ed7fa641923955005661a --- contrib/exhaustive_search.py | 39 ++++++++---- faiss/IndexBinary.cpp | 11 +++- faiss/IndexBinary.h | 27 +++------ faiss/IndexBinaryFromFloat.cpp | 1 + faiss/IndexIDMap.cpp | 17 +++++- faiss/IndexIDMap.h | 4 +- faiss/IndexReplicas.cpp | 45 +++++++------- faiss/IndexShards.cpp | 48 +++++++-------- faiss/clone_index.cpp | 11 ++++ faiss/clone_index.h | 3 + faiss/gpu/GpuCloner.cpp | 73 +++++++++++++++++++++++ faiss/gpu/GpuCloner.h | 22 +++++++ faiss/gpu/test/TestGpuIndexBinaryFlat.cpp | 25 ++++++++ faiss/gpu/test/test_contrib_gpu.py | 38 +++++++++++- faiss/gpu/test/test_multi_gpu.py | 28 ++++++++- faiss/python/class_wrappers.py | 33 ++++++---- faiss/python/gpu_wrappers.py | 6 +- faiss/python/swigfaiss.swig | 6 ++ faiss/utils/utils.cpp | 20 +++++-- faiss/utils/utils.h | 11 ++-- tests/test_contrib.py | 30 ++++++++++ tests/test_index_binary.py | 2 +- 22 files changed, 384 insertions(+), 116 deletions(-) diff --git a/contrib/exhaustive_search.py b/contrib/exhaustive_search.py index 19b452a25d..eadb097fae 100644 --- a/contrib/exhaustive_search.py +++ b/contrib/exhaustive_search.py @@ -60,12 +60,18 @@ def range_search_gpu(xq, r2, index_gpu, index_cpu, gpu_k=1024): - None. In that case, at most gpu_k results will be returned """ nq, d = xq.shape - k = min(index_gpu.ntotal, gpu_k) + is_binary_index = isinstance(index_gpu, faiss.IndexBinary) keep_max = faiss.is_similarity_metric(index_gpu.metric_type) - LOG.debug(f"GPU search {nq} queries with {k=:}") + r2 = int(r2) if is_binary_index else float(r2) + k = min(index_gpu.ntotal, gpu_k) + LOG.debug( + f"GPU search {nq} queries with {k=:} {is_binary_index=:} {keep_max=:}") t0 = time.time() D, I = index_gpu.search(xq, k) t1 = time.time() - t0 + if is_binary_index: + assert d * 8 < 32768 # let's compact the distance matrix + D = D.astype('int16') t2 = 0 lim_remain = None if index_cpu is not None: @@ -79,14 +85,24 @@ def range_search_gpu(xq, r2, index_gpu, index_cpu, gpu_k=1024): if isinstance(index_cpu, np.ndarray): # then it in fact an array that we have to make flat xb = index_cpu - index_cpu = faiss.IndexFlat(d, index_gpu.metric_type) + if is_binary_index: + index_cpu = faiss.IndexBinaryFlat(d * 8) + else: + index_cpu = faiss.IndexFlat(d, index_gpu.metric_type) index_cpu.add(xb) lim_remain, D_remain, I_remain = index_cpu.range_search(xq[mask], r2) + if is_binary_index: + D_remain = D_remain.astype('int16') t2 = time.time() - t0 LOG.debug("combine") t0 = time.time() - combiner = faiss.CombinerRangeKNN(nq, k, float(r2), keep_max) + CombinerRangeKNN = ( + faiss.CombinerRangeKNNint16 if is_binary_index else + faiss.CombinerRangeKNNfloat + ) + + combiner = CombinerRangeKNN(nq, k, r2, keep_max) if True: sp = faiss.swig_ptr combiner.I = sp(I) @@ -101,7 +117,7 @@ def range_search_gpu(xq, r2, index_gpu, index_cpu, gpu_k=1024): L_res = np.empty(nq + 1, dtype='int64') combiner.compute_sizes(sp(L_res)) nres = L_res[-1] - D_res = np.empty(nres, dtype='float32') + D_res = np.empty(nres, dtype=D.dtype) I_res = np.empty(nres, dtype='int64') combiner.write_result(sp(D_res), sp(I_res)) else: @@ -251,6 +267,7 @@ def range_search_max_results(index, query_iterator, radius, """ # TODO: all result manipulations are in python, should move to C++ if perf # critical + is_binary_index = isinstance(index, faiss.IndexBinary) if min_results is None: assert max_results is not None @@ -268,6 +285,8 @@ def range_search_max_results(index, query_iterator, radius, co = faiss.GpuMultipleClonerOptions() co.shard = shard index_gpu = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu) + else: + index_gpu = None t_start = time.time() t_search = t_post_process = 0 @@ -276,7 +295,8 @@ def range_search_max_results(index, query_iterator, radius, for xqi in query_iterator: t0 = time.time() - if ngpu > 0: + LOG.debug(f"searching {len(xqi)} vectors") + if index_gpu: lims_i, Di, Ii = range_search_gpu(xqi, radius, index_gpu, index) else: lims_i, Di, Ii = index.range_search(xqi, radius) @@ -286,8 +306,7 @@ def range_search_max_results(index, query_iterator, radius, qtot += len(xqi) t1 = time.time() - if xqi.dtype != np.float32: - # for binary indexes + if is_binary_index: # weird Faiss quirk that returns floats for Hamming distances Di = Di.astype('int16') @@ -299,7 +318,7 @@ def range_search_max_results(index, query_iterator, radius, (totres, max_results)) radius, totres = apply_maxres( res_batches, min_results, - keep_max=faiss.is_similarity_metric(index.metric_type) + keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT ) t2 = time.time() t_search += t1 - t0 @@ -315,7 +334,7 @@ def range_search_max_results(index, query_iterator, radius, if clip_to_min and totres > min_results: radius, totres = apply_maxres( res_batches, min_results, - keep_max=faiss.is_similarity_metric(index.metric_type) + keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT ) nres = np.hstack([nres_i for nres_i, dis_i, ids_i in res_batches]) diff --git a/faiss/IndexBinary.cpp b/faiss/IndexBinary.cpp index 8d9eef3099..22d8bc5e1b 100644 --- a/faiss/IndexBinary.cpp +++ b/faiss/IndexBinary.cpp @@ -15,6 +15,11 @@ namespace faiss { +IndexBinary::IndexBinary(idx_t d, MetricType metric) + : d(d), code_size(d / 8), metric_type(metric) { + FAISS_THROW_IF_NOT(d % 8 == 0); +} + IndexBinary::~IndexBinary() {} void IndexBinary::train(idx_t, const uint8_t*) { @@ -51,7 +56,7 @@ void IndexBinary::reconstruct(idx_t, uint8_t*) const { void IndexBinary::reconstruct_n(idx_t i0, idx_t ni, uint8_t* recons) const { for (idx_t i = 0; i < ni; i++) { - reconstruct(i0 + i, recons + i * d); + reconstruct(i0 + i, recons + i * code_size); } } @@ -70,10 +75,10 @@ void IndexBinary::search_and_reconstruct( for (idx_t j = 0; j < k; ++j) { idx_t ij = i * k + j; idx_t key = labels[ij]; - uint8_t* reconstructed = recons + ij * d; + uint8_t* reconstructed = recons + ij * code_size; if (key < 0) { // Fill with NaNs - memset(reconstructed, -1, sizeof(*reconstructed) * d); + memset(reconstructed, -1, code_size); } else { reconstruct(key, reconstructed); } diff --git a/faiss/IndexBinary.h b/faiss/IndexBinary.h index 35a0aff79c..a026f34177 100644 --- a/faiss/IndexBinary.h +++ b/faiss/IndexBinary.h @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. */ -// -*- c++ -*- - #ifndef FAISS_INDEX_BINARY_H #define FAISS_INDEX_BINARY_H @@ -16,7 +14,6 @@ #include #include -#include namespace faiss { @@ -35,27 +32,19 @@ struct IndexBinary { using component_t = uint8_t; using distance_t = int32_t; - int d; ///< vector dimension - int code_size; ///< number of bytes per vector ( = d / 8 ) - idx_t ntotal; ///< total nb of indexed vectors - bool verbose; ///< verbosity level + int d = 0; ///< vector dimension + int code_size = 0; ///< number of bytes per vector ( = d / 8 ) + idx_t ntotal = 0; ///< total nb of indexed vectors + bool verbose = false; ///< verbosity level /// set if the Index does not require training, or if training is done /// already - bool is_trained; + bool is_trained = true; /// type of metric this index uses for search - MetricType metric_type; - - explicit IndexBinary(idx_t d = 0, MetricType metric = METRIC_L2) - : d(d), - code_size(d / 8), - ntotal(0), - verbose(false), - is_trained(true), - metric_type(metric) { - FAISS_THROW_IF_NOT(d % 8 == 0); - } + MetricType metric_type = METRIC_L2; + + explicit IndexBinary(idx_t d = 0, MetricType metric = METRIC_L2); virtual ~IndexBinary(); diff --git a/faiss/IndexBinaryFromFloat.cpp b/faiss/IndexBinaryFromFloat.cpp index 490320f9f1..e396231f4c 100644 --- a/faiss/IndexBinaryFromFloat.cpp +++ b/faiss/IndexBinaryFromFloat.cpp @@ -9,6 +9,7 @@ #include +#include #include #include #include diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp index ecac7ed454..fdd7fcd734 100644 --- a/faiss/IndexIDMap.cpp +++ b/faiss/IndexIDMap.cpp @@ -21,18 +21,31 @@ namespace faiss { +namespace { + +// IndexBinary needs to update the code_size when d is set... + +void sync_d(Index* index) {} + +void sync_d(IndexBinary* index) { + FAISS_THROW_IF_NOT(index->d % 8 == 0); + index->code_size = index->d / 8; +} + +} // anonymous namespace + /***************************************************** * IndexIDMap implementation *******************************************************/ template -IndexIDMapTemplate::IndexIDMapTemplate(IndexT* index) - : index(index), own_fields(false) { +IndexIDMapTemplate::IndexIDMapTemplate(IndexT* index) : index(index) { FAISS_THROW_IF_NOT_MSG(index->ntotal == 0, "index must be empty on input"); this->is_trained = index->is_trained; this->metric_type = index->metric_type; this->verbose = index->verbose; this->d = index->d; + sync_d(this); } template diff --git a/faiss/IndexIDMap.h b/faiss/IndexIDMap.h index 2297e5e51c..2d16412301 100644 --- a/faiss/IndexIDMap.h +++ b/faiss/IndexIDMap.h @@ -22,8 +22,8 @@ struct IndexIDMapTemplate : IndexT { using component_t = typename IndexT::component_t; using distance_t = typename IndexT::distance_t; - IndexT* index; ///! the sub-index - bool own_fields; ///! whether pointers are deleted in destructo + IndexT* index = nullptr; ///! the sub-index + bool own_fields = false; ///! whether pointers are deleted in destructo std::vector id_map; explicit IndexIDMapTemplate(IndexT* index); diff --git a/faiss/IndexReplicas.cpp b/faiss/IndexReplicas.cpp index 39a07f0996..8295f34a60 100644 --- a/faiss/IndexReplicas.cpp +++ b/faiss/IndexReplicas.cpp @@ -12,17 +12,34 @@ namespace faiss { +namespace { + +// IndexBinary needs to update the code_size when d is set... + +void sync_d(Index* index) {} + +void sync_d(IndexBinary* index) { + FAISS_THROW_IF_NOT(index->d % 8 == 0); + index->code_size = index->d / 8; +} + +} // anonymous namespace + template IndexReplicasTemplate::IndexReplicasTemplate(bool threaded) : ThreadedIndex(threaded) {} template IndexReplicasTemplate::IndexReplicasTemplate(idx_t d, bool threaded) - : ThreadedIndex(d, threaded) {} + : ThreadedIndex(d, threaded) { + sync_d(this); +} template IndexReplicasTemplate::IndexReplicasTemplate(int d, bool threaded) - : ThreadedIndex(d, threaded) {} + : ThreadedIndex(d, threaded) { + sync_d(this); +} template void IndexReplicasTemplate::onAfterAddIndex(IndexT* index) { @@ -168,6 +185,8 @@ void IndexReplicasTemplate::syncWithSubIndexes() { } auto firstIndex = this->at(0); + this->d = firstIndex->d; + sync_d(this); this->metric_type = firstIndex->metric_type; this->is_trained = firstIndex->is_trained; this->ntotal = firstIndex->ntotal; @@ -181,28 +200,6 @@ void IndexReplicasTemplate::syncWithSubIndexes() { } } -// No metric_type for IndexBinary -template <> -void IndexReplicasTemplate::syncWithSubIndexes() { - if (!this->count()) { - this->is_trained = false; - this->ntotal = 0; - - return; - } - - auto firstIndex = this->at(0); - this->is_trained = firstIndex->is_trained; - this->ntotal = firstIndex->ntotal; - - for (int i = 1; i < this->count(); ++i) { - auto index = this->at(i); - FAISS_THROW_IF_NOT(this->d == index->d); - FAISS_THROW_IF_NOT(this->is_trained == index->is_trained); - FAISS_THROW_IF_NOT(this->ntotal == index->ntotal); - } -} - // explicit instantiations template struct IndexReplicasTemplate; template struct IndexReplicasTemplate; diff --git a/faiss/IndexShards.cpp b/faiss/IndexShards.cpp index e32b57eff2..e8074469fd 100644 --- a/faiss/IndexShards.cpp +++ b/faiss/IndexShards.cpp @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. */ -// -*- c++ -*- - #include #include @@ -22,6 +20,15 @@ namespace faiss { // subroutines namespace { +// IndexBinary needs to update the code_size when d is set... + +void sync_d(Index* index) {} + +void sync_d(IndexBinary* index) { + FAISS_THROW_IF_NOT(index->d % 8 == 0); + index->code_size = index->d / 8; +} + // add translation to all valid labels void translate_labels(int64_t n, idx_t* labels, int64_t translation) { if (translation == 0) @@ -40,20 +47,26 @@ IndexShardsTemplate::IndexShardsTemplate( idx_t d, bool threaded, bool successive_ids) - : ThreadedIndex(d, threaded), successive_ids(successive_ids) {} + : ThreadedIndex(d, threaded), successive_ids(successive_ids) { + sync_d(this); +} template IndexShardsTemplate::IndexShardsTemplate( int d, bool threaded, bool successive_ids) - : ThreadedIndex(d, threaded), successive_ids(successive_ids) {} + : ThreadedIndex(d, threaded), successive_ids(successive_ids) { + sync_d(this); +} template IndexShardsTemplate::IndexShardsTemplate( bool threaded, bool successive_ids) - : ThreadedIndex(threaded), successive_ids(successive_ids) {} + : ThreadedIndex(threaded), successive_ids(successive_ids) { + sync_d(this); +} template void IndexShardsTemplate::onAfterAddIndex(IndexT* index /* unused */) { @@ -78,6 +91,8 @@ void IndexShardsTemplate::syncWithSubIndexes() { } auto firstIndex = this->at(0); + this->d = firstIndex->d; + sync_d(this); this->metric_type = firstIndex->metric_type; this->is_trained = firstIndex->is_trained; this->ntotal = firstIndex->ntotal; @@ -92,29 +107,6 @@ void IndexShardsTemplate::syncWithSubIndexes() { } } -// No metric_type for IndexBinary -template <> -void IndexShardsTemplate::syncWithSubIndexes() { - if (!this->count()) { - this->is_trained = false; - this->ntotal = 0; - - return; - } - - auto firstIndex = this->at(0); - this->is_trained = firstIndex->is_trained; - this->ntotal = firstIndex->ntotal; - - for (int i = 1; i < this->count(); ++i) { - auto index = this->at(i); - FAISS_THROW_IF_NOT(this->d == index->d); - FAISS_THROW_IF_NOT(this->is_trained == index->is_trained); - - this->ntotal += index->ntotal; - } -} - template void IndexShardsTemplate::train(idx_t n, const component_t* x) { auto fn = [n, x](int no, IndexT* index) { diff --git a/faiss/clone_index.cpp b/faiss/clone_index.cpp index 2a6bb30abf..44ab1f7cc3 100644 --- a/faiss/clone_index.cpp +++ b/faiss/clone_index.cpp @@ -17,6 +17,8 @@ #include #include #include +#include +#include #include #include #include @@ -35,6 +37,7 @@ #include #include #include + #include #include @@ -385,4 +388,12 @@ Quantizer* clone_Quantizer(const Quantizer* quant) { FAISS_THROW_MSG("Did not recognize quantizer to clone"); } +IndexBinary* clone_binary_index(const IndexBinary* index) { + if (auto ii = dynamic_cast(index)) { + return new IndexBinaryFlat(*ii); + } else { + FAISS_THROW_MSG("cannot clone this type of index"); + } +} + } // namespace faiss diff --git a/faiss/clone_index.h b/faiss/clone_index.h index 640053ee13..397c0ff5f7 100644 --- a/faiss/clone_index.h +++ b/faiss/clone_index.h @@ -17,6 +17,7 @@ struct Index; struct IndexIVF; struct VectorTransform; struct Quantizer; +struct IndexBinary; /* cloning functions */ Index* clone_index(const Index*); @@ -33,4 +34,6 @@ struct Cloner { Quantizer* clone_Quantizer(const Quantizer* quant); +IndexBinary* clone_binary_index(const IndexBinary* index); + } // namespace faiss diff --git a/faiss/gpu/GpuCloner.cpp b/faiss/gpu/GpuCloner.cpp index a77ee0485d..94b587b2ed 100644 --- a/faiss/gpu/GpuCloner.cpp +++ b/faiss/gpu/GpuCloner.cpp @@ -11,6 +11,7 @@ #include +#include #include #include #include @@ -21,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -483,5 +485,76 @@ Index* GpuProgressiveDimIndexFactory::operator()(int dim) { return index_cpu_to_gpu_multiple(vres, devices, &index, &options); } +/********************************************* + * Cloning binary indexes + *********************************************/ + +faiss::IndexBinary* index_binary_gpu_to_cpu( + const faiss::IndexBinary* gpu_index) { + if (auto ii = dynamic_cast(gpu_index)) { + IndexBinaryFlat* ret = new IndexBinaryFlat(); + ii->copyTo(ret); + return ret; + } else { + FAISS_THROW_MSG("cannot clone this type of index"); + } +} + +faiss::IndexBinary* index_binary_cpu_to_gpu( + GpuResourcesProvider* provider, + int device, + const faiss::IndexBinary* index, + const GpuClonerOptions* options) { + if (auto ii = dynamic_cast(index)) { + GpuIndexBinaryFlatConfig config; + config.device = device; + if (options) { + config.use_raft = options->use_raft; + } + return new GpuIndexBinaryFlat(provider, ii, config); + } else { + FAISS_THROW_MSG("cannot clone this type of index"); + } +} + +faiss::IndexBinary* index_binary_cpu_to_gpu_multiple( + std::vector& provider, + std::vector& devices, + const faiss::IndexBinary* index, + const GpuMultipleClonerOptions* options) { + GpuMultipleClonerOptions defaults; + FAISS_THROW_IF_NOT(devices.size() == provider.size()); + int n = devices.size(); + if (n == 1) { + return index_binary_cpu_to_gpu(provider[0], devices[0], index, options); + } + if (!options) { + options = &defaults; + } + if (options->shard) { + auto* fi = dynamic_cast(index); + FAISS_THROW_IF_NOT_MSG(fi, "only flat index cloning supported"); + IndexBinaryShards* ret = new IndexBinaryShards(true, true); + for (int i = 0; i < n; i++) { + IndexBinaryFlat fig(fi->d); + size_t i0 = i * fi->ntotal / n; + size_t i1 = (i + 1) * fi->ntotal / n; + fig.add(i1 - i0, fi->xb.data() + i0 * fi->code_size); + ret->addIndex(index_binary_cpu_to_gpu( + provider[i], devices[i], &fig, options)); + } + ret->own_indices = true; + return ret; + } else { // replicas + IndexBinaryReplicas* ret = new IndexBinaryReplicas(true); + for (int i = 0; i < n; i++) { + ret->addIndex(index_binary_cpu_to_gpu( + provider[i], devices[i], index, options)); + } + ret->own_indices = true; + return ret; + } +} + } // namespace gpu } // namespace faiss diff --git a/faiss/gpu/GpuCloner.h b/faiss/gpu/GpuCloner.h index 584f1163fc..105e83a0c5 100644 --- a/faiss/gpu/GpuCloner.h +++ b/faiss/gpu/GpuCloner.h @@ -11,10 +11,12 @@ #include #include +#include #include #include #include #include + namespace faiss { namespace gpu { @@ -95,5 +97,25 @@ struct GpuProgressiveDimIndexFactory : ProgressiveDimIndexFactory { virtual ~GpuProgressiveDimIndexFactory() override; }; +/********************************************* + * Cloning binary indexes + *********************************************/ + +faiss::IndexBinary* index_binary_gpu_to_cpu( + const faiss::IndexBinary* gpu_index); + +/// converts any CPU index that can be converted to GPU +faiss::IndexBinary* index_binary_cpu_to_gpu( + GpuResourcesProvider* provider, + int device, + const faiss::IndexBinary* index, + const GpuClonerOptions* options = nullptr); + +faiss::IndexBinary* index_binary_cpu_to_gpu_multiple( + std::vector& provider, + std::vector& devices, + const faiss::IndexBinary* index, + const GpuMultipleClonerOptions* options = nullptr); + } // namespace gpu } // namespace faiss diff --git a/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp b/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp index 85179e64cb..037a579df6 100644 --- a/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +++ b/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -164,6 +165,30 @@ TEST(TestGpuIndexBinaryFlat, LargeIndex) { compareBinaryDist(cpuDist, cpuLabels, gpuDist, gpuLabels, nq, k); } +TEST(TestGpuIndexBinaryFlat, Reconstruct) { + int n = 1000; + std::vector xb(8 * n); + faiss::byte_rand(xb.data(), xb.size(), 123); + std::unique_ptr index( + new faiss::IndexBinaryFlat(64)); + index->add(n, xb.data()); + + std::vector xb3(8 * n); + index->reconstruct_n(0, index->ntotal, xb3.data()); + EXPECT_EQ(xb, xb3); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + std::unique_ptr index2( + new faiss::gpu::GpuIndexBinaryFlat(&res, index.get())); + + std::vector xb2(8 * n); + + index2->reconstruct_n(0, index->ntotal, xb2.data()); + EXPECT_EQ(xb2, xb3); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); diff --git a/faiss/gpu/test/test_contrib_gpu.py b/faiss/gpu/test/test_contrib_gpu.py index faeeb5d0d8..e6c28ac3aa 100644 --- a/faiss/gpu/test/test_contrib_gpu.py +++ b/faiss/gpu/test/test_contrib_gpu.py @@ -12,7 +12,8 @@ from faiss.contrib import datasets, evaluation, big_batch_search from faiss.contrib.exhaustive_search import knn_ground_truth, \ - range_ground_truth, range_search_gpu + range_ground_truth, range_search_gpu, \ + range_search_max_results, exponential_query_iterator class TestComputeGT(unittest.TestCase): @@ -62,6 +63,41 @@ def test_range_L2(self): def test_range_IP(self): self.do_test_range(faiss.METRIC_INNER_PRODUCT) + def test_max_results_binary(self, ngpu=1): + ds = datasets.SyntheticDataset(64, 1000, 1000, 200) + tobinary = faiss.index_factory(ds.d, "LSHrt") + tobinary.train(ds.get_train()) + index = faiss.IndexBinaryFlat(ds.d) + xb = tobinary.sa_encode(ds.get_database()) + xq = tobinary.sa_encode(ds.get_queries()) + index.add(xb) + + # find a reasonable radius + D, _ = index.search(xq, 10) + radius0 = int(np.median(D[:, -1])) + + # baseline = search with that radius + lims_ref, Dref, Iref = index.range_search(xq, radius0) + + # now see if using just the total number of results, we can get back + # the same result table + query_iterator = exponential_query_iterator(xq) + + radius1, lims_new, Dnew, Inew = range_search_max_results( + index, query_iterator, ds.d // 2, + min_results=Dref.size, clip_to_min=True, + ngpu=1 + ) + + evaluation.check_ref_range_results( + lims_ref, Dref, Iref, + lims_new, Dnew, Inew + ) + + @unittest.skipIf(faiss.get_num_gpus() < 2, "multiple GPU only test") + def test_max_results_binary_multigpu(self): + self.test_max_results_binary(ngpu=2) + class TestBigBatchSearch(unittest.TestCase): diff --git a/faiss/gpu/test/test_multi_gpu.py b/faiss/gpu/test/test_multi_gpu.py index 568d444a6e..95762d4321 100644 --- a/faiss/gpu/test/test_multi_gpu.py +++ b/faiss/gpu/test/test_multi_gpu.py @@ -9,7 +9,7 @@ import faiss from faiss.contrib.datasets import SyntheticDataset - +from faiss.contrib.evaluation import check_ref_knn_with_draws class TestShardedFlat(unittest.TestCase): @@ -98,6 +98,32 @@ def test_sharded_IVFSQ(self): def test_sharded_IVF_HNSW(self): self.do_test_sharded_ivf("IVF1000_HNSW,Flat") + def test_binary_clone(self, ngpu=1, shard=False): + ds = SyntheticDataset(64, 1000, 1000, 200) + tobinary = faiss.index_factory(ds.d, "LSHrt") + tobinary.train(ds.get_train()) + index = faiss.IndexBinaryFlat(ds.d) + xb = tobinary.sa_encode(ds.get_database()) + xq = tobinary.sa_encode(ds.get_queries()) + index.add(xb) + Dref, Iref = index.search(xq, 5) + + co = faiss.GpuMultipleClonerOptions() + co.shard = shard + + # index2 = faiss.index_cpu_to_all_gpus(index, ngpu=ngpu) + res = faiss.StandardGpuResources() + index2 = faiss.GpuIndexBinaryFlat(res, index) + + Dnew, Inew = index2.search(xq, 5) + check_ref_knn_with_draws(Dref, Iref, Dnew, Inew) + + def test_binary_clone_replicas(self): + self.test_binary_clone(ngpu=2, shard=False) + + def test_binary_clone_shards(self): + self.test_binary_clone(ngpu=2, shard=True) + # This class also has a multi-GPU test within class EvalIVFPQAccuracy(unittest.TestCase): diff --git a/faiss/python/class_wrappers.py b/faiss/python/class_wrappers.py index 7d410d4afb..efce359220 100644 --- a/faiss/python/class_wrappers.py +++ b/faiss/python/class_wrappers.py @@ -495,7 +495,7 @@ def replacement_reconstruct_n(self, n0=0, ni=-1, x=None): Reconstructed vectors, size (`ni`, `self.d`), `dtype`=float32 """ if ni == -1: - ni = self.ntotal + ni = self.ntotal - n0 if x is None: x = np.empty((ni, self.d), dtype=np.float32) else: @@ -767,32 +767,43 @@ def handle_IndexBinary(the_class): def replacement_add(self, x): n, d = x.shape x = _check_dtype_uint8(x) - assert d * 8 == self.d + assert d == self.code_size self.add_c(n, swig_ptr(x)) def replacement_add_with_ids(self, x, ids): n, d = x.shape x = _check_dtype_uint8(x) ids = np.ascontiguousarray(ids, dtype='int64') - assert d * 8 == self.d + assert d == self.code_size assert ids.shape == (n, ), 'not same nb of vectors as ids' self.add_with_ids_c(n, swig_ptr(x), swig_ptr(ids)) def replacement_train(self, x): n, d = x.shape x = _check_dtype_uint8(x) - assert d * 8 == self.d + assert d == self.code_size self.train_c(n, swig_ptr(x)) def replacement_reconstruct(self, key): - x = np.empty(self.d // 8, dtype=np.uint8) + x = np.empty(self.code_size, dtype=np.uint8) self.reconstruct_c(key, swig_ptr(x)) return x + def replacement_reconstruct_n(self, n0=0, ni=-1, x=None): + if ni == -1: + ni = self.ntotal - n0 + if x is None: + x = np.empty((ni, self.code_size), dtype=np.uint8) + else: + assert x.shape == (ni, self.code_size) + + self.reconstruct_n_c(n0, ni, swig_ptr(x)) + return x + def replacement_search(self, x, k): x = _check_dtype_uint8(x) n, d = x.shape - assert d * 8 == self.d + assert d == self.code_size assert k > 0 distances = np.empty((n, k), dtype=np.int32) labels = np.empty((n, k), dtype=np.int64) @@ -804,7 +815,7 @@ def replacement_search(self, x, k): def replacement_search_preassigned(self, x, k, Iq, Dq): n, d = x.shape x = _check_dtype_uint8(x) - assert d * 8 == self.d + assert d == self.code_size assert k > 0 D = np.empty((n, k), dtype=np.int32) @@ -829,7 +840,7 @@ def replacement_search_preassigned(self, x, k, Iq, Dq): def replacement_range_search(self, x, thresh): n, d = x.shape x = _check_dtype_uint8(x) - assert d * 8 == self.d + assert d == self.code_size res = RangeSearchResult(n) self.range_search_c(n, swig_ptr(x), thresh, res) # get pointers and copy them @@ -842,7 +853,7 @@ def replacement_range_search(self, x, thresh): def replacement_range_search_preassigned(self, x, thresh, Iq, Dq, *, params=None): n, d = x.shape x = _check_dtype_uint8(x) - assert d * 8 == self.d + assert d == self.code_size Iq = np.ascontiguousarray(Iq, dtype='int64') assert params is None, "params not supported" @@ -866,9 +877,6 @@ def replacement_range_search_preassigned(self, x, thresh, Iq, Dq, *, params=None I = rev_swig_ptr(res.labels, nd).copy() return lims, D, I - - - def replacement_remove_ids(self, x): if isinstance(x, IDSelector): sel = x @@ -884,6 +892,7 @@ def replacement_remove_ids(self, x): replace_method(the_class, 'search', replacement_search) replace_method(the_class, 'range_search', replacement_range_search) replace_method(the_class, 'reconstruct', replacement_reconstruct) + replace_method(the_class, 'reconstruct_n', replacement_reconstruct_n) replace_method(the_class, 'remove_ids', replacement_remove_ids) replace_method(the_class, 'search_preassigned', replacement_search_preassigned, ignore_missing=True) diff --git a/faiss/python/gpu_wrappers.py b/faiss/python/gpu_wrappers.py index ba0b5fe9f6..6e788511d2 100644 --- a/faiss/python/gpu_wrappers.py +++ b/faiss/python/gpu_wrappers.py @@ -29,8 +29,10 @@ def index_cpu_to_gpu_multiple_py(resources, index, co=None, gpus=None): for i, res in zip(gpus, resources): vdev.push_back(i) vres.push_back(res) - index = index_cpu_to_gpu_multiple(vres, vdev, index, co) - return index + if isinstance(index, IndexBinary): + return index_binary_cpu_to_gpu_multiple(vres, vdev, index, co) + else: + return index_cpu_to_gpu_multiple(vres, vdev, index, co) def index_cpu_to_all_gpus(index, co=None, ngpu=-1): diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig index 4b187e5991..7ebc6624e5 100644 --- a/faiss/python/swigfaiss.swig +++ b/faiss/python/swigfaiss.swig @@ -124,6 +124,7 @@ typedef uint64_t size_t; #include #include + #include #include #include @@ -251,6 +252,7 @@ namespace std { %template(ClusteringIterationStatsVector) std::vector; %template(ParameterRangeVector) std::vector; + #ifndef SWIGWIN %template(OnDiskOneListVector) std::vector; #endif // !SWIGWIN @@ -385,6 +387,9 @@ void gpu_sync_all_devices() // order matters because includes are not recursive %include +%template(CombinerRangeKNNfloat) faiss::CombinerRangeKNN; +%template(CombinerRangeKNNint16) faiss::CombinerRangeKNN; + %include %include %include @@ -580,6 +585,7 @@ void gpu_sync_all_devices() %newobject read_VectorTransform; %newobject read_ProductQuantizer; %newobject clone_index; +%newobject clone_binary_index; %newobject clone_Quantizer; %newobject clone_VectorTransform; diff --git a/faiss/utils/utils.cpp b/faiss/utils/utils.cpp index 5f3e5eb964..2eee790c20 100644 --- a/faiss/utils/utils.cpp +++ b/faiss/utils/utils.cpp @@ -556,7 +556,8 @@ bool check_openmp() { namespace { -int64_t count_lt(int64_t n, const float* row, float threshold) { +template +int64_t count_lt(int64_t n, const T* row, T threshold) { for (int64_t i = 0; i < n; i++) { if (!(row[i] < threshold)) { return i; @@ -565,7 +566,8 @@ int64_t count_lt(int64_t n, const float* row, float threshold) { return n; } -int64_t count_gt(int64_t n, const float* row, float threshold) { +template +int64_t count_gt(int64_t n, const T* row, T threshold) { for (int64_t i = 0; i < n; i++) { if (!(row[i] > threshold)) { return i; @@ -576,14 +578,15 @@ int64_t count_gt(int64_t n, const float* row, float threshold) { } // namespace -void CombinerRangeKNN::compute_sizes(int64_t* L_res) { +template +void CombinerRangeKNN::compute_sizes(int64_t* L_res) { this->L_res = L_res; L_res[0] = 0; int64_t j = 0; for (int64_t i = 0; i < nq; i++) { int64_t n_in; if (!mask || !mask[i]) { - const float* row = D + i * k; + const T* row = D + i * k; n_in = keep_max ? count_gt(k, row, r2) : count_lt(k, row, r2); } else { n_in = lim_remain[j + 1] - lim_remain[j]; @@ -597,12 +600,13 @@ void CombinerRangeKNN::compute_sizes(int64_t* L_res) { } } -void CombinerRangeKNN::write_result(float* D_res, int64_t* I_res) { +template +void CombinerRangeKNN::write_result(T* D_res, int64_t* I_res) { FAISS_THROW_IF_NOT(L_res); int64_t j = 0; for (int64_t i = 0; i < nq; i++) { int64_t n_in = L_res[i + 1] - L_res[i]; - float* D_row = D_res + L_res[i]; + T* D_row = D_res + L_res[i]; int64_t* I_row = I_res + L_res[i]; if (!mask || !mask[i]) { memcpy(D_row, D + i * k, n_in * sizeof(*D_row)); @@ -615,4 +619,8 @@ void CombinerRangeKNN::write_result(float* D_res, int64_t* I_res) { } } +// explicit template instantiations +template struct CombinerRangeKNN; +template struct CombinerRangeKNN; + } // namespace faiss diff --git a/faiss/utils/utils.h b/faiss/utils/utils.h index 373869b58f..8578be9447 100644 --- a/faiss/utils/utils.h +++ b/faiss/utils/utils.h @@ -178,25 +178,26 @@ bool check_openmp(); /** This class is used to combine range and knn search results * in contrib.exhaustive_search.range_search_gpu */ +template struct CombinerRangeKNN { int64_t nq; /// nb of queries size_t k; /// number of neighbors for the knn search part - float r2; /// range search radius + T r2; /// range search radius bool keep_max; /// whether to keep max values instead of min. - CombinerRangeKNN(int64_t nq, size_t k, float r2, bool keep_max) + CombinerRangeKNN(int64_t nq, size_t k, T r2, bool keep_max) : nq(nq), k(k), r2(r2), keep_max(keep_max) {} /// Knn search results const int64_t* I = nullptr; /// size nq * k - const float* D = nullptr; /// size nq * k + const T* D = nullptr; /// size nq * k /// optional: range search results (ignored if mask is NULL) const bool* mask = nullptr; /// mask for where knn results are valid, size nq // range search results for remaining entries nrange = sum(mask) const int64_t* lim_remain = nullptr; /// size nrange + 1 - const float* D_remain = nullptr; /// size lim_remain[nrange] + const T* D_remain = nullptr; /// size lim_remain[nrange] const int64_t* I_remain = nullptr; /// size lim_remain[nrange] const int64_t* L_res = nullptr; /// size nq + 1 @@ -205,7 +206,7 @@ struct CombinerRangeKNN { /// Phase 2: caller allocates D_res and I_res (size L_res[nq]) /// Phase 3: fill in D_res and I_res - void write_result(float* D_res, int64_t* I_res); + void write_result(T* D_res, int64_t* I_res); }; } // namespace faiss diff --git a/tests/test_contrib.py b/tests/test_contrib.py index 61a5023ddd..528c779c15 100644 --- a/tests/test_contrib.py +++ b/tests/test_contrib.py @@ -457,6 +457,36 @@ def test_L2(self): def test_IP(self): self.do_test(faiss.METRIC_INNER_PRODUCT) + def test_binary(self): + ds = datasets.SyntheticDataset(64, 1000, 1000, 200) + tobinary = faiss.index_factory(ds.d, "LSHrt") + tobinary.train(ds.get_train()) + index = faiss.IndexBinaryFlat(ds.d) + xb = tobinary.sa_encode(ds.get_database()) + xq = tobinary.sa_encode(ds.get_queries()) + index.add(xb) + + # find a reasonable radius + D, _ = index.search(xq, 10) + radius0 = int(np.median(D[:, -1])) + + # baseline = search with that radius + lims_ref, Dref, Iref = index.range_search(xq, radius0) + + # now see if using just the total number of results, we can get back + # the same result table + query_iterator = exponential_query_iterator(xq) + + radius1, lims_new, Dnew, Inew = range_search_max_results( + index, query_iterator, ds.d // 2, + min_results=Dref.size, clip_to_min=True + ) + + evaluation.check_ref_range_results( + lims_ref, Dref, Iref, + lims_new, Dnew, Inew + ) + class TestClustering(unittest.TestCase): diff --git a/tests/test_index_binary.py b/tests/test_index_binary.py index 19ffed4d09..312530ad46 100644 --- a/tests/test_index_binary.py +++ b/tests/test_index_binary.py @@ -6,7 +6,6 @@ """this is a basic test script for simple indices work""" import os -import sys import numpy as np import unittest import faiss @@ -374,6 +373,7 @@ def test_replicas(self): sub_idx = faiss.IndexBinaryFlat(d) sub_idx.add(xb) index.addIndex(sub_idx) + self.assertEqual(index_ref.code_size, index.code_size) D, I = index.search(xq, 10)