Skip to content

Commit

Permalink
remove a layer of unneeded filters; relying on compiler de-virtualiza…
Browse files Browse the repository at this point in the history
…tion

Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
  • Loading branch information
alexanderguzhva committed Aug 21, 2024
1 parent a164d64 commit 924049a
Show file tree
Hide file tree
Showing 11 changed files with 164 additions and 172 deletions.
4 changes: 2 additions & 2 deletions include/knowhere/bitsetview_idselector.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

namespace knowhere {

struct BitsetViewIDSelector : faiss::IDSelector {
BitsetView bitset_view;
struct BitsetViewIDSelector final : faiss::IDSelector {
const BitsetView bitset_view;

inline BitsetViewIDSelector(BitsetView bitset_view) : bitset_view{bitset_view} {
}
Expand Down
35 changes: 0 additions & 35 deletions src/index/hnsw/impl/BitsetFilter.h

This file was deleted.

50 changes: 40 additions & 10 deletions src/index/hnsw/impl/IndexBruteForceWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,37 @@
#include <algorithm>
#include <memory>

#include "index/hnsw/impl/BitsetFilter.h"
#include "knowhere/bitsetview.h"
#include "knowhere/bitsetview_idselector.h"

namespace knowhere {

using idx_t = faiss::idx_t;

// the following structure is a hack, because GCC cannot properly
// de-virtualize a plain BitsetViewIDSelector.
struct BitsetViewIDSelectorWrapper final {
const BitsetView bitset_view;

inline BitsetViewIDSelectorWrapper(BitsetView bitset_view) : bitset_view{bitset_view} {
}

[[nodiscard]] inline bool
is_member(faiss::idx_t id) const {
// it is by design that bitset_view.empty() is not tested here
return (!bitset_view.test(id));
}
};

//
IndexBruteForceWrapper::IndexBruteForceWrapper(faiss::Index* underlying_index)
: faiss::cppcontrib::knowhere::IndexWrapper{underlying_index} {
}

void
IndexBruteForceWrapper::search(faiss::idx_t n, const float* __restrict x, faiss::idx_t k, float* __restrict distances,
faiss::idx_t* __restrict labels, const faiss::SearchParameters* params) const {
faiss::idx_t* __restrict labels,
const faiss::SearchParameters* __restrict params) const {
FAISS_THROW_IF_NOT(k > 0);

std::unique_ptr<faiss::DistanceComputer> dis(index->get_distance_computer());
Expand All @@ -47,8 +62,8 @@ IndexBruteForceWrapper::search(faiss::idx_t n, const float* __restrict x, faiss:
dis->set_query(x + i * index->d);

// allocate heap
idx_t* const local_ids = labels + i * index->d;
float* const local_distances = distances + i * index->d;
idx_t* const __restrict local_ids = labels + i * index->d;
float* const __restrict local_distances = distances + i * index->d;

// set up a filter
faiss::IDSelector* sel = (params == nullptr) ? nullptr : params->sel;
Expand All @@ -59,20 +74,35 @@ IndexBruteForceWrapper::search(faiss::idx_t n, const float* __restrict x, faiss:
}

// try knowhere-specific filter
const knowhere::BitsetViewIDSelector* bw_idselector = dynamic_cast<const knowhere::BitsetViewIDSelector*>(sel);
const knowhere::BitsetViewIDSelector* __restrict bw_idselector =
dynamic_cast<const knowhere::BitsetViewIDSelector*>(sel);

knowhere::BitsetFilter filter(bw_idselector->bitset_view);
BitsetViewIDSelectorWrapper bw_idselector_w(bw_idselector->bitset_view);

if (is_similarity_metric(index->metric_type)) {
using C = faiss::CMin<float, idx_t>;

faiss::cppcontrib::knowhere::brute_force_search_impl<C, faiss::DistanceComputer, knowhere::BitsetFilter>(
index->ntotal, *dis, filter, k, local_distances, local_ids);
if (bw_idselector == nullptr || bw_idselector->bitset_view.empty()) {
faiss::IDSelectorAll sel_all;
faiss::cppcontrib::knowhere::brute_force_search_impl<C, faiss::DistanceComputer, faiss::IDSelectorAll>(
index->ntotal, *dis, sel_all, k, local_distances, local_ids);
} else {
faiss::cppcontrib::knowhere::brute_force_search_impl<C, faiss::DistanceComputer,
BitsetViewIDSelectorWrapper>(
index->ntotal, *dis, bw_idselector_w, k, local_distances, local_ids);
}
} else {
using C = faiss::CMax<float, idx_t>;

faiss::cppcontrib::knowhere::brute_force_search_impl<C, faiss::DistanceComputer, knowhere::BitsetFilter>(
index->ntotal, *dis, filter, k, local_distances, local_ids);
if (bw_idselector == nullptr || bw_idselector->bitset_view.empty()) {
faiss::IDSelectorAll sel_all;
faiss::cppcontrib::knowhere::brute_force_search_impl<C, faiss::DistanceComputer, faiss::IDSelectorAll>(
index->ntotal, *dis, sel_all, k, local_distances, local_ids);
} else {
faiss::cppcontrib::knowhere::brute_force_search_impl<C, faiss::DistanceComputer,
BitsetViewIDSelectorWrapper>(
index->ntotal, *dis, bw_idselector_w, k, local_distances, local_ids);
}
}
}
}
Expand Down
48 changes: 25 additions & 23 deletions src/index/hnsw/impl/IndexHNSWWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include <faiss/IndexHNSW.h>
#include <faiss/MetricType.h>
#include <faiss/cppcontrib/knowhere/impl/Bruteforce.h>
#include <faiss/cppcontrib/knowhere/impl/Filters.h>
#include <faiss/cppcontrib/knowhere/impl/HnswSearcher.h>
#include <faiss/cppcontrib/knowhere/utils/Bitset.h>
#include <faiss/impl/AuxIndexStructures.h>
Expand All @@ -27,7 +26,6 @@
#include <cstdint>
#include <memory>

#include "index/hnsw/impl/BitsetFilter.h"
#include "index/hnsw/impl/FederVisitor.h"
#include "knowhere/bitsetview.h"
#include "knowhere/bitsetview_idselector.h"
Expand Down Expand Up @@ -83,7 +81,7 @@ IndexHNSWWrapper::IndexHNSWWrapper(faiss::IndexHNSW* underlying_index)

void
IndexHNSWWrapper::search(idx_t n, const float* __restrict x, idx_t k, float* __restrict distances,
idx_t* __restrict labels, const faiss::SearchParameters* params_in) const {
idx_t* __restrict labels, const faiss::SearchParameters* __restrict params_in) const {
FAISS_THROW_IF_NOT(k > 0);

const faiss::IndexHNSW* index_hnsw = dynamic_cast<const faiss::IndexHNSW*>(index);
Expand Down Expand Up @@ -150,61 +148,65 @@ IndexHNSWWrapper::search(idx_t n, const float* __restrict x, idx_t k, float* __r
faiss::IDSelector* sel = (params == nullptr) ? nullptr : params->sel;

// try knowhere-specific filter
const knowhere::BitsetViewIDSelector* bw_idselector = dynamic_cast<const knowhere::BitsetViewIDSelector*>(sel);
const knowhere::BitsetViewIDSelector* __restrict bw_idselector =
dynamic_cast<const knowhere::BitsetViewIDSelector*>(sel);

if (bw_idselector == nullptr) {
if (bw_idselector == nullptr || bw_idselector->bitset_view.empty()) {
// no filter
faiss::cppcontrib::knowhere::AllowAllFilter filter;
faiss::IDSelectorAll sel_all;

// feder templating is important, bcz it removes an unneeded 'CALL' instruction.
if (feder == nullptr) {
// no feder
DummyVisitor graph_visitor;

using searcher_type =
faiss::cppcontrib::knowhere::v2_hnsw_searcher<faiss::DistanceComputer, DummyVisitor,
faiss::cppcontrib::knowhere::Bitset,
faiss::cppcontrib::knowhere::AllowAllFilter>;
using searcher_type = faiss::cppcontrib::knowhere::v2_hnsw_searcher<
faiss::DistanceComputer, DummyVisitor, faiss::cppcontrib::knowhere::Bitset, faiss::IDSelectorAll>;

searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, filter, kAlpha, params};
searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes,
sel_all, kAlpha, params};

local_stats = searcher.search(k, distances + i * k, labels + i * k);
} else {
// use feder
FederVisitor graph_visitor(feder);

using searcher_type =
faiss::cppcontrib::knowhere::v2_hnsw_searcher<faiss::DistanceComputer, FederVisitor,
faiss::cppcontrib::knowhere::Bitset,
faiss::cppcontrib::knowhere::AllowAllFilter>;
using searcher_type = faiss::cppcontrib::knowhere::v2_hnsw_searcher<
faiss::DistanceComputer, FederVisitor, faiss::cppcontrib::knowhere::Bitset, faiss::IDSelectorAll>;

searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, filter, kAlpha, params};
searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes,
sel_all, kAlpha, params};

local_stats = searcher.search(k, distances + i * k, labels + i * k);
}
} else {
// with filter
knowhere::BitsetFilter filter(bw_idselector->bitset_view);

// feder templating is important, bcz it removes an unneeded 'CALL' instruction.
if (feder == nullptr) {
// no feder
DummyVisitor graph_visitor;

using searcher_type = faiss::cppcontrib::knowhere::v2_hnsw_searcher<
faiss::DistanceComputer, DummyVisitor, faiss::cppcontrib::knowhere::Bitset, knowhere::BitsetFilter>;
using searcher_type =
faiss::cppcontrib::knowhere::v2_hnsw_searcher<faiss::DistanceComputer, DummyVisitor,
faiss::cppcontrib::knowhere::Bitset,
knowhere::BitsetViewIDSelector>;

searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, filter, kAlpha, params};
searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes,
*bw_idselector, kAlpha, params};

local_stats = searcher.search(k, distances + i * k, labels + i * k);
} else {
// use feder
FederVisitor graph_visitor(feder);

using searcher_type = faiss::cppcontrib::knowhere::v2_hnsw_searcher<
faiss::DistanceComputer, FederVisitor, faiss::cppcontrib::knowhere::Bitset, knowhere::BitsetFilter>;
using searcher_type =
faiss::cppcontrib::knowhere::v2_hnsw_searcher<faiss::DistanceComputer, FederVisitor,
faiss::cppcontrib::knowhere::Bitset,
knowhere::BitsetViewIDSelector>;

searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, filter, kAlpha, params};
searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes,
*bw_idselector, kAlpha, params};

local_stats = searcher.search(k, distances + i * k, labels + i * k);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include <faiss/impl/IDSelector.h>

#include <faiss/cppcontrib/knowhere/impl/Bruteforce.h>
#include <faiss/cppcontrib/knowhere/impl/Filters.h>

namespace faiss {
namespace cppcontrib {
Expand All @@ -37,7 +36,7 @@ void IndexBruteForceWrapper::search(
idx_t k,
float* __restrict distances,
idx_t* __restrict labels,
const SearchParameters* params
const SearchParameters* __restrict params
) const {
FAISS_THROW_IF_NOT(k > 0);

Expand All @@ -57,39 +56,60 @@ void IndexBruteForceWrapper::search(
dis->set_query(x + i * index->d);

// allocate heap
idx_t* const local_ids = labels + i * index->d;
float* const local_distances = distances + i * index->d;
idx_t* const __restrict local_ids = labels + i * index->d;
float* const __restrict local_distances = distances + i * index->d;

// set up a filter
IDSelector* sel = (params == nullptr) ? nullptr : params->sel;

// template, just in case a filter type will be specialized
// in order to remove virtual function call overhead.
using filter_type = DefaultIDSelectorFilter<IDSelector>;
filter_type filter(sel);
IDSelector* __restrict sel = (params == nullptr) ? nullptr : params->sel;

if (is_similarity_metric(index->metric_type)) {
using C = CMin<float, idx_t>;

brute_force_search_impl<C, DistanceComputer, filter_type>(
index->ntotal,
*dis,
filter,
k,
local_distances,
local_ids
);
if (sel == nullptr) {
// Compiler is expected to de-virtualize virtual method calls
IDSelectorAll sel_all;
brute_force_search_impl<C, DistanceComputer, IDSelectorAll>(
index->ntotal,
*dis,
sel_all,
k,
local_distances,
local_ids
);
} else {
brute_force_search_impl<C, DistanceComputer, IDSelector>(
index->ntotal,
*dis,
*sel,
k,
local_distances,
local_ids
);
}
} else {
using C = CMax<float, idx_t>;

brute_force_search_impl<C, DistanceComputer, filter_type>(
index->ntotal,
*dis,
filter,
k,
local_distances,
local_ids
);
if (sel == nullptr) {
// Compiler is expected to de-virtualize virtual method calls
IDSelectorAll sel_all;
brute_force_search_impl<C, DistanceComputer, IDSelectorAll>(
index->ntotal,
*dis,
sel_all,
k,
local_distances,
local_ids
);
} else {
brute_force_search_impl<C, DistanceComputer, IDSelector>(
index->ntotal,
*dis,
*sel,
k,
local_distances,
local_ids
);
}
}
}
}
Expand Down
Loading

0 comments on commit 924049a

Please sign in to comment.