Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: BF supports ids of base_data starting from a specific value #941

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions include/knowhere/bitsetview_idselector.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@ namespace knowhere {

// faiss::IDSelector adapter over a BitsetView: an id is reported as a member
// when the corresponding bit in the bitset is NOT set.
struct BitsetViewIDSelector final : faiss::IDSelector {
// Bitset consulted by is_member().
const BitsetView bitset_view;
// Value added to every incoming id before the bitset lookup; the constructor defaults it to 0.
const size_t id_offset;

inline BitsetViewIDSelector(BitsetView bitset_view) : bitset_view{bitset_view} {
inline BitsetViewIDSelector(BitsetView bitset_view, const size_t offset = 0)
: bitset_view{bitset_view}, id_offset(offset) {
}

// Returns true when the bit for `id` (shifted by id_offset) is clear in bitset_view.
// NOTE(review): `id` is a faiss::idx_t while id_offset is a size_t, so the sum mixes
// signed and unsigned operands — confirm negative ids never reach this selector.
inline bool
is_member(faiss::idx_t id) const override final {
// it is by design that bitset_view.empty() is not tested here
return (!bitset_view.test(id));
return (!bitset_view.test(id + id_offset));
}
};

Expand Down
1 change: 1 addition & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ constexpr const char* RETAIN_ITERATOR_ORDER = "retain_iterator_order";
constexpr const char* RADIUS = "radius";
constexpr const char* RANGE_FILTER = "range_filter";
constexpr const char* INPUT_IDS = "input_ids";
// Key under which a DataSet stores the begin id of its tensor (see DataSet::SetTensorBeginId).
constexpr const char* INPUT_BEG_ID = "input_begin_id";
constexpr const char* OUTPUT_TENSOR = "output_tensor";
constexpr const char* DEVICE_ID = "gpu_id";
constexpr const char* NUM_BUILD_THREAD = "num_build_thread";
Expand Down
20 changes: 19 additions & 1 deletion include/knowhere/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ class DataSet : public std::enable_shared_from_this<const DataSet> {
this->data_[meta::DIM] = Var(std::in_place_index<4>, dim);
}

// Stores `offset` as the begin id of this dataset's tensor under meta::INPUT_BEG_ID.
// The write is performed while holding an exclusive lock on mutex_.
void
SetTensorBeginId(const int64_t offset) {
    std::unique_lock write_guard(mutex_);
    auto& begin_id_slot = this->data_[meta::INPUT_BEG_ID];
    begin_id_slot = Var(std::in_place_index<4>, offset);
}

void
SetJsonInfo(const std::string& info) {
std::unique_lock lock(mutex_);
Expand Down Expand Up @@ -260,6 +266,17 @@ class DataSet : public std::enable_shared_from_this<const DataSet> {
this->is_sparse = is_sparse;
}

// Returns the begin id stored for this dataset's tensor under meta::INPUT_BEG_ID.
// Falls back to 0 when no begin id is stored, or when the stored entry does not
// hold the variant alternative at index 4 (the one SetTensorBeginId() writes).
// The lookup is performed while holding a shared lock on mutex_.
int64_t
GetTensorBeginId() const {
    std::shared_lock lock(mutex_);
    const auto it = this->data_.find(meta::INPUT_BEG_ID);
    if (it == this->data_.end()) {
        return 0;
    }
    // std::get_if yields nullptr on an alternative mismatch; dereferencing it
    // unchecked would be undefined behavior, so treat a mismatch as "not set".
    const auto* begin_id = std::get_if<4>(&it->second);
    return begin_id != nullptr ? *begin_id : 0;
}

// deprecated API
template <typename T>
void
Expand Down Expand Up @@ -288,12 +305,13 @@ class DataSet : public std::enable_shared_from_this<const DataSet> {
using DataSetPtr = std::shared_ptr<DataSet>;

// Builds a DataSet describing `nb` rows of dimension `dim` whose tensor pointer is `xb`.
// The dataset is marked as not owning the tensor (SetIsOwner(false)), and `beg_id`
// (default 0) is recorded as the tensor begin id.
inline DataSetPtr
GenDataSet(const int64_t nb, const int64_t dim, const void* xb) {
GenDataSet(const int64_t nb, const int64_t dim, const void* xb, const int64_t beg_id = 0) {
auto ret_ds = std::make_shared<DataSet>();
ret_ds->SetRows(nb);
ret_ds->SetDim(dim);
ret_ds->SetTensor(xb);
ret_ds->SetIsOwner(false);
ret_ds->SetTensorBeginId(beg_id);
return ret_ds;
}

Expand Down
43 changes: 34 additions & 9 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
namespace knowhere {

/* knowhere wrapper API to call faiss brute force search for all metric types */
/* If the ids of base_dataset do not start from 0, the BF functions filter on the real ids and return the
* real ids. */

class BruteForceConfig : public BaseConfig {};

Expand Down Expand Up @@ -70,6 +72,7 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset

auto xb = base->GetTensor();
auto nb = base->GetRows();
auto xb_id_offset = base->GetTensorBeginId();
auto dim = base->GetDim();

auto xq = query->GetTensor();
Expand Down Expand Up @@ -121,7 +124,7 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
auto cur_labels = labels_ptr + topk * index;
auto cur_distances = distances_ptr + topk * index;

BitsetViewIDSelector bw_idselector(bitset);
BitsetViewIDSelector bw_idselector(bitset, xb_id_offset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

switch (faiss_metric_type) {
Expand Down Expand Up @@ -179,6 +182,11 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
if (xb_id_offset != 0) {
for (auto i = 0; i < nq * topk; i++) {
labels[i] = labels[i] == -1 ? -1 : labels[i] + xb_id_offset;
}
}
auto res = GenResultDataSet(nq, cfg.k.value(), std::move(labels), std::move(distances));

#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
Expand All @@ -202,6 +210,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
auto xb = base->GetTensor();
auto nb = base->GetRows();
auto dim = base->GetDim();
auto xb_id_offset = base->GetTensorBeginId();

auto xq = query->GetTensor();
auto nq = query->GetRows();
Expand Down Expand Up @@ -248,7 +257,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
auto cur_labels = labels + topk * index;
auto cur_distances = distances + topk * index;

BitsetViewIDSelector bw_idselector(bitset);
BitsetViewIDSelector bw_idselector(bitset, xb_id_offset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

switch (faiss_metric_type) {
Expand Down Expand Up @@ -311,6 +320,11 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
}
// LCOV_EXCL_STOP
#endif
if (xb_id_offset != 0) {
for (auto i = 0; i < nq * topk; i++) {
labels[i] = labels[i] == -1 ? -1 : labels[i] + xb_id_offset;
}
}

return Status::success;
}
Expand All @@ -331,6 +345,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
auto xb = base->GetTensor();
auto nb = base->GetRows();
auto dim = base->GetDim();
auto xb_id_offset = base->GetTensorBeginId();

auto xq = query->GetTensor();
auto nq = query->GetRows();
Expand Down Expand Up @@ -423,7 +438,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
ThreadPool::ScopedSearchOmpSetter setter(1);
faiss::RangeSearchResult res(1);

BitsetViewIDSelector bw_idselector(bitset);
BitsetViewIDSelector bw_idselector(bitset, xb_id_offset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

switch (faiss_metric_type) {
Expand Down Expand Up @@ -469,7 +484,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
result_id_array[index].resize(elem_cnt);
for (size_t j = 0; j < elem_cnt; j++) {
result_dist_array[index][j] = res.distances[j];
result_id_array[index][j] = res.labels[j];
result_id_array[index][j] = res.labels[j] + xb_id_offset;
}
if (cfg.range_filter.value() != defaultRangeFilter) {
FilterRangeSearchResultForOneNq(result_dist_array[index], result_id_array[index], is_ip, radius,
Expand Down Expand Up @@ -504,6 +519,7 @@ BruteForce::SearchSparseWithBuf(const DataSetPtr base_dataset, const DataSetPtr
auto base = static_cast<const sparse::SparseRow<float>*>(base_dataset->GetTensor());
auto rows = base_dataset->GetRows();
auto dim = base_dataset->GetDim();
auto xb_id_offset = base_dataset->GetTensorBeginId();

auto xq = static_cast<const sparse::SparseRow<float>*>(query_dataset->GetTensor());
auto nq = query_dataset->GetRows();
Expand Down Expand Up @@ -561,7 +577,8 @@ BruteForce::SearchSparseWithBuf(const DataSetPtr base_dataset, const DataSetPtr
}
sparse::MaxMinHeap<float> heap(topk);
for (int64_t j = 0; j < rows; ++j) {
if (!bitset.empty() && bitset.test(j)) {
auto x_id = j + xb_id_offset;
if (!bitset.empty() && bitset.test(x_id)) {
continue;
}
float row_sum = 0;
Expand All @@ -573,7 +590,7 @@ BruteForce::SearchSparseWithBuf(const DataSetPtr base_dataset, const DataSetPtr
}
float dist = row.dot(base[j], computer, row_sum);
if (dist > 0) {
heap.push(j, dist);
heap.push(x_id, dist);
}
}
int result_size = heap.size();
Expand Down Expand Up @@ -626,6 +643,7 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da
auto xb = base->GetTensor();
auto nb = base->GetRows();
auto dim = base->GetDim();
auto xb_id_offset = base->GetTensorBeginId();

auto xq = query->GetTensor();
auto nq = query->GetRows();
Expand Down Expand Up @@ -669,7 +687,7 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da
futs.emplace_back(pool->push([&, index = i] {
ThreadPool::ScopedSearchOmpSetter setter(1);

BitsetViewIDSelector bw_idselector(bitset);
BitsetViewIDSelector bw_idselector(bitset, xb_id_offset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;
auto larger_is_closer = faiss::is_similarity_metric(faiss_metric_type) || is_cosine;
auto max_dis = larger_is_closer ? std::numeric_limits<float>::lowest() : std::numeric_limits<float>::max();
Expand Down Expand Up @@ -697,6 +715,11 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da
return Status::invalid_metric_type;
}
}
if (xb_id_offset != 0) {
for (auto& distances_id : distances_ids) {
distances_id.id = distances_id.id == -1 ? -1 : distances_id.id + xb_id_offset;
}
}
vec[index] = std::make_shared<PrecomputedDistanceIterator>(std::move(distances_ids), larger_is_closer);

return Status::success;
Expand Down Expand Up @@ -726,6 +749,7 @@ BruteForce::AnnIterator<knowhere::sparse::SparseRow<float>>(const DataSetPtr bas
auto base = static_cast<const sparse::SparseRow<float>*>(base_dataset->GetTensor());
auto rows = base_dataset->GetRows();
auto dim = base_dataset->GetDim();
auto xb_id_offset = base_dataset->GetTensorBeginId();

auto xq = static_cast<const sparse::SparseRow<float>*>(query_dataset->GetTensor());
auto nq = query_dataset->GetRows();
Expand Down Expand Up @@ -776,7 +800,8 @@ BruteForce::AnnIterator<knowhere::sparse::SparseRow<float>>(const DataSetPtr bas
std::vector<DistId> distances_ids;
if (row.size() > 0) {
for (int64_t j = 0; j < rows; ++j) {
if (!bitset.empty() && bitset.test(j)) {
auto xb_id = j + xb_id_offset;
if (!bitset.empty() && bitset.test(xb_id)) {
continue;
}
float row_sum = 0;
Expand All @@ -788,7 +813,7 @@ BruteForce::AnnIterator<knowhere::sparse::SparseRow<float>>(const DataSetPtr bas
}
auto dist = row.dot(base[j], computer, row_sum);
if (dist > 0) {
distances_ids.emplace_back(j, dist);
distances_ids.emplace_back(xb_id, dist);
}
}
}
Expand Down
47 changes: 47 additions & 0 deletions tests/ut/test_bruteforce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "catch2/catch_approx.hpp"
#include "catch2/catch_test_macros.hpp"
#include "catch2/generators/catch_generators.hpp"
#include "faiss/utils/Heap.h"
#include "knowhere/comp/brute_force.h"
#include "knowhere/comp/index_param.h"
#include "knowhere/utils.h"
Expand Down Expand Up @@ -196,3 +197,49 @@ TEST_CASE("Test Brute Force", "[binary vector]") {
}
}
}

// Checks that brute-force search honors a non-zero begin id: the base data is searched
// block by block (each block tagged with the id of its first row), the per-block top-k
// results are merged in a faiss max-heap, and the merged result must equal the result of
// a single search over the whole dataset with the same bitset.
TEST_CASE("Test Brute Force with input ids", "[float vector]") {
    using Catch::Approx;
    const int64_t nb = 1000;
    const int64_t nq = 1;
    const int64_t dim = 128;
    const int64_t k = 10;
    const knowhere::Json conf = {
        {knowhere::meta::DIM, dim},
        {knowhere::meta::METRIC_TYPE, "L2"},
        {knowhere::meta::TOPK, k},
    };
    // Consecutive entries delimit one block of rows: [prefix[i], prefix[i + 1]).
    const std::vector<int64_t> block_prefix = {0, 333, 500, 555, 1000};

    // generate filter id and data
    auto filter_bits = GenerateBitsetWithRandomTbitsSet(nb, 100);
    knowhere::BitsetView bitset(filter_bits.data(), nb);

    const auto total_train_ds = GenDataSet(nb, dim);
    const auto query_ds = GenDataSet(nq, dim);

    // Running top-k over all blocks; ids/dis are the heap's backing storage.
    std::vector<float> dis(nq * k, std::numeric_limits<float>::quiet_NaN());
    std::vector<int64_t> ids(nq * k, -1);
    faiss::float_maxheap_array_t heaps{nq, k, ids.data(), dis.data()};
    heaps.heapify();

    // `i + 1 < size()` keeps the comparison unsigned and cannot wrap the way `size() - 1` can.
    for (size_t i = 0; i + 1 < block_prefix.size(); i++) {
        const auto begin_id = block_prefix[i];
        const auto end_id = block_prefix[i + 1];
        const auto blk_rows = end_id - begin_id;
        const auto tensor = static_cast<const float*>(total_train_ds->GetTensor()) + dim * begin_id;
        auto blk_train_ds = knowhere::GenDataSet(blk_rows, dim, tensor, begin_id);
        auto partial_v = knowhere::BruteForce::Search<knowhere::fp32>(blk_train_ds, query_ds, conf, bitset);
        REQUIRE(partial_v.has_value());
        auto partial_res = partial_v.value();
        // Feed this block's k candidates per query into the running heap.
        heaps.addn_with_ids(k, partial_res->GetDistance(), partial_res->GetIds(), k, 0, nq);
    }
    heaps.reorder();

    // Ground truth: one search over the full dataset.
    auto gt = knowhere::BruteForce::Search<knowhere::fp32>(total_train_ds, query_ds, conf, bitset);
    // Fail the test cleanly instead of calling value() on an error result.
    REQUIRE(gt.has_value());
    auto gt_ids = gt.value()->GetIds();
    auto gt_dis = gt.value()->GetDistance();
    for (int64_t i = 0; i < nq * k; i++) {
        REQUIRE(gt_ids[i] == ids[i]);
        REQUIRE(gt_dis[i] == dis[i]);
    }
}
Loading