From d6f77e977e2c9203506eb373421ef2fc65fd48b2 Mon Sep 17 00:00:00 2001 From: cqy123456 Date: Wed, 13 Nov 2024 03:13:47 -0500 Subject: [PATCH] enhance: BF supports ids of base_data starting from a specific value Signed-off-by: cqy123456 --- include/knowhere/bitsetview_idselector.h | 6 ++- include/knowhere/comp/index_param.h | 1 + include/knowhere/dataset.h | 20 +++++++++- src/common/comp/brute_force.cc | 43 +++++++++++++++++----- tests/ut/test_bruteforce.cc | 47 ++++++++++++++++++++++++ 5 files changed, 105 insertions(+), 12 deletions(-) diff --git a/include/knowhere/bitsetview_idselector.h b/include/knowhere/bitsetview_idselector.h index c09aba222..39f6ff1a8 100644 --- a/include/knowhere/bitsetview_idselector.h +++ b/include/knowhere/bitsetview_idselector.h @@ -19,14 +19,16 @@ namespace knowhere { struct BitsetViewIDSelector final : faiss::IDSelector { const BitsetView bitset_view; + const size_t id_offset; - inline BitsetViewIDSelector(BitsetView bitset_view) : bitset_view{bitset_view} { + inline BitsetViewIDSelector(BitsetView bitset_view, const size_t offset = 0) + : bitset_view{bitset_view}, id_offset(offset) { } inline bool is_member(faiss::idx_t id) const override final { // it is by design that bitset_view.empty() is not tested here - return (!bitset_view.test(id)); + return (!bitset_view.test(id + id_offset)); } }; diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index 4051962c1..1452fc2d9 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -85,6 +85,7 @@ constexpr const char* RETAIN_ITERATOR_ORDER = "retain_iterator_order"; constexpr const char* RADIUS = "radius"; constexpr const char* RANGE_FILTER = "range_filter"; constexpr const char* INPUT_IDS = "input_ids"; +constexpr const char* INPUT_BEG_ID = "input_begin_id"; constexpr const char* OUTPUT_TENSOR = "output_tensor"; constexpr const char* DEVICE_ID = "gpu_id"; constexpr const char* NUM_BUILD_THREAD = "num_build_thread"; 
diff --git a/include/knowhere/dataset.h b/include/knowhere/dataset.h index 1a7b37169..6bc6e61cb 100644 --- a/include/knowhere/dataset.h +++ b/include/knowhere/dataset.h @@ -142,6 +142,12 @@ class DataSet : public std::enable_shared_from_this { this->data_[meta::DIM] = Var(std::in_place_index<4>, dim); } + void + SetTensorBeginId(const int64_t offset) { + std::unique_lock lock(mutex_); + this->data_[meta::INPUT_BEG_ID] = Var(std::in_place_index<4>, offset); + } + void SetJsonInfo(const std::string& info) { std::unique_lock lock(mutex_); @@ -260,6 +266,17 @@ class DataSet : public std::enable_shared_from_this { this->is_sparse = is_sparse; } + int64_t + GetTensorBeginId() const { + std::shared_lock lock(mutex_); + auto it = this->data_.find(meta::INPUT_BEG_ID); + if (it != this->data_.end()) { + int64_t res = *std::get_if<4>(&it->second); + return res; + } + return 0; + } + // deprecated API template void @@ -288,12 +305,13 @@ class DataSet : public std::enable_shared_from_this { using DataSetPtr = std::shared_ptr; inline DataSetPtr -GenDataSet(const int64_t nb, const int64_t dim, const void* xb) { +GenDataSet(const int64_t nb, const int64_t dim, const void* xb, const int64_t beg_id = 0) { auto ret_ds = std::make_shared(); ret_ds->SetRows(nb); ret_ds->SetDim(dim); ret_ds->SetTensor(xb); ret_ds->SetIsOwner(false); + ret_ds->SetTensorBeginId(beg_id); return ret_ds; } diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc index 647792620..cbeafb89a 100644 --- a/src/common/comp/brute_force.cc +++ b/src/common/comp/brute_force.cc @@ -34,6 +34,8 @@ namespace knowhere { /* knowhere wrapper API to call faiss brute force search for all metric types */ +/* If the ids of base_dataset do not start from 0, the BF functions will filter based on the real ids and return the + * real ids. */ class BruteForceConfig : public BaseConfig {}; @@ -70,6 +72,7 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset auto xb = 
base->GetTensor(); auto nb = base->GetRows(); + auto xb_id_offset = base->GetTensorBeginId(); auto dim = base->GetDim(); auto xq = query->GetTensor(); @@ -121,7 +124,7 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset auto cur_labels = labels_ptr + topk * index; auto cur_distances = distances_ptr + topk * index; - BitsetViewIDSelector bw_idselector(bitset); + BitsetViewIDSelector bw_idselector(bitset, xb_id_offset); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; switch (faiss_metric_type) { @@ -179,6 +182,11 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset if (ret != Status::success) { return expected::Err(ret, "failed to brute force search"); } + if (xb_id_offset != 0) { + for (auto i = 0; i < nq * topk; i++) { + labels[i] = labels[i] == -1 ? -1 : labels[i] + xb_id_offset; + } + } auto res = GenResultDataSet(nq, cfg.k.value(), std::move(labels), std::move(distances)); #if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) @@ -202,6 +210,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ auto xb = base->GetTensor(); auto nb = base->GetRows(); auto dim = base->GetDim(); + auto xb_id_offset = base->GetTensorBeginId(); auto xq = query->GetTensor(); auto nq = query->GetRows(); @@ -248,7 +257,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ auto cur_labels = labels + topk * index; auto cur_distances = distances + topk * index; - BitsetViewIDSelector bw_idselector(bitset); + BitsetViewIDSelector bw_idselector(bitset, xb_id_offset); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; switch (faiss_metric_type) { @@ -311,6 +320,11 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ } // LCOV_EXCL_STOP #endif + if (xb_id_offset != 0) { + for (auto i = 0; i < nq * topk; i++) { + labels[i] = labels[i] == -1 ? 
-1 : labels[i] + xb_id_offset; + } + } return Status::success; } @@ -331,6 +345,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da auto xb = base->GetTensor(); auto nb = base->GetRows(); auto dim = base->GetDim(); + auto xb_id_offset = base->GetTensorBeginId(); auto xq = query->GetTensor(); auto nq = query->GetRows(); @@ -423,7 +438,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da ThreadPool::ScopedSearchOmpSetter setter(1); faiss::RangeSearchResult res(1); - BitsetViewIDSelector bw_idselector(bitset); + BitsetViewIDSelector bw_idselector(bitset, xb_id_offset); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; switch (faiss_metric_type) { @@ -469,7 +484,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da result_id_array[index].resize(elem_cnt); for (size_t j = 0; j < elem_cnt; j++) { result_dist_array[index][j] = res.distances[j]; - result_id_array[index][j] = res.labels[j]; + result_id_array[index][j] = res.labels[j] + xb_id_offset; } if (cfg.range_filter.value() != defaultRangeFilter) { FilterRangeSearchResultForOneNq(result_dist_array[index], result_id_array[index], is_ip, radius, @@ -504,6 +519,7 @@ BruteForce::SearchSparseWithBuf(const DataSetPtr base_dataset, const DataSetPtr auto base = static_cast*>(base_dataset->GetTensor()); auto rows = base_dataset->GetRows(); auto dim = base_dataset->GetDim(); + auto xb_id_offset = base_dataset->GetTensorBeginId(); auto xq = static_cast*>(query_dataset->GetTensor()); auto nq = query_dataset->GetRows(); @@ -561,7 +577,8 @@ BruteForce::SearchSparseWithBuf(const DataSetPtr base_dataset, const DataSetPtr } sparse::MaxMinHeap heap(topk); for (int64_t j = 0; j < rows; ++j) { - if (!bitset.empty() && bitset.test(j)) { + auto x_id = j + xb_id_offset; + if (!bitset.empty() && bitset.test(x_id)) { continue; } float row_sum = 0; @@ -573,7 +590,7 @@ BruteForce::SearchSparseWithBuf(const DataSetPtr 
base_dataset, const DataSetPtr } float dist = row.dot(base[j], computer, row_sum); if (dist > 0) { - heap.push(j, dist); + heap.push(x_id, dist); } } int result_size = heap.size(); @@ -626,6 +643,7 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da auto xb = base->GetTensor(); auto nb = base->GetRows(); auto dim = base->GetDim(); + auto xb_id_offset = base->GetTensorBeginId(); auto xq = query->GetTensor(); auto nq = query->GetRows(); @@ -669,7 +687,7 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da futs.emplace_back(pool->push([&, index = i] { ThreadPool::ScopedSearchOmpSetter setter(1); - BitsetViewIDSelector bw_idselector(bitset); + BitsetViewIDSelector bw_idselector(bitset, xb_id_offset); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; auto larger_is_closer = faiss::is_similarity_metric(faiss_metric_type) || is_cosine; auto max_dis = larger_is_closer ? std::numeric_limits::lowest() : std::numeric_limits::max(); @@ -697,6 +715,11 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da return Status::invalid_metric_type; } } + if (xb_id_offset != 0) { + for (auto& distances_id : distances_ids) { + distances_id.id = distances_id.id == -1 ? 
-1 : distances_id.id + xb_id_offset; + } + } vec[index] = std::make_shared(std::move(distances_ids), larger_is_closer); return Status::success; @@ -726,6 +749,7 @@ BruteForce::AnnIterator>(const DataSetPtr bas auto base = static_cast*>(base_dataset->GetTensor()); auto rows = base_dataset->GetRows(); auto dim = base_dataset->GetDim(); + auto xb_id_offset = base_dataset->GetTensorBeginId(); auto xq = static_cast*>(query_dataset->GetTensor()); auto nq = query_dataset->GetRows(); @@ -776,7 +800,8 @@ BruteForce::AnnIterator>(const DataSetPtr bas std::vector distances_ids; if (row.size() > 0) { for (int64_t j = 0; j < rows; ++j) { - if (!bitset.empty() && bitset.test(j)) { + auto xb_id = j + xb_id_offset; + if (!bitset.empty() && bitset.test(xb_id)) { continue; } float row_sum = 0; @@ -788,7 +813,7 @@ BruteForce::AnnIterator>(const DataSetPtr bas } auto dist = row.dot(base[j], computer, row_sum); if (dist > 0) { - distances_ids.emplace_back(j, dist); + distances_ids.emplace_back(xb_id, dist); } } } diff --git a/tests/ut/test_bruteforce.cc b/tests/ut/test_bruteforce.cc index 04c338ce7..a8124fff7 100644 --- a/tests/ut/test_bruteforce.cc +++ b/tests/ut/test_bruteforce.cc @@ -12,6 +12,7 @@ #include "catch2/catch_approx.hpp" #include "catch2/catch_test_macros.hpp" #include "catch2/generators/catch_generators.hpp" +#include "faiss/utils/Heap.h" #include "knowhere/comp/brute_force.h" #include "knowhere/comp/index_param.h" #include "knowhere/utils.h" @@ -196,3 +197,49 @@ TEST_CASE("Test Brute Force", "[binary vector]") { } } } + +TEST_CASE("Test Brute Force with input ids", "[float vector]") { + using Catch::Approx; + const int64_t nb = 1000; + const int64_t nq = 1; + const int64_t dim = 128; + const int64_t k = 10; + const knowhere::Json conf = { + {knowhere::meta::DIM, dim}, + {knowhere::meta::METRIC_TYPE, "L2"}, + {knowhere::meta::TOPK, k}, + }; + std::vector block_prefix = {0, 333, 500, 555, 1000}; + + // generate filter id and data + auto filter_bits = 
GenerateBitsetWithRandomTbitsSet(nb, 100); + knowhere::BitsetView bitset(filter_bits.data(), nb); + + const auto total_train_ds = GenDataSet(nb, dim); + const auto query_ds = GenDataSet(nq, dim); + + std::vector dis(nq * k, std::numeric_limits::quiet_NaN()); + std::vector ids(nq * k, -1); + faiss::float_maxheap_array_t heaps{nq, k, ids.data(), dis.data()}; + heaps.heapify(); + for (auto i = 0; i < block_prefix.size() - 1; i++) { + auto begin_id = block_prefix[i]; + auto end_id = block_prefix[i + 1]; + auto blk_rows = end_id - begin_id; + auto tensor = (const float*)total_train_ds->GetTensor() + dim * begin_id; + auto blk_train_ds = knowhere::GenDataSet(blk_rows, dim, tensor, begin_id); + auto partial_v = knowhere::BruteForce::Search(blk_train_ds, query_ds, conf, bitset); + REQUIRE(partial_v.has_value()); + auto partial_res = partial_v.value(); + heaps.addn_with_ids(k, partial_res->GetDistance(), partial_res->GetIds(), k, 0, nq); + } + heaps.reorder(); + + auto gt = knowhere::BruteForce::Search(total_train_ds, query_ds, conf, bitset); + auto gt_ids = gt.value()->GetIds(); + auto gt_dis = gt.value()->GetDistance(); + for (auto i = 0; i < nq * k; i++) { + REQUIRE(gt_ids[i] == ids[i]); + REQUIRE(gt_dis[i] == dis[i]); + } +}