diff --git a/include/knowhere/bitsetview_idselector.h b/include/knowhere/bitsetview_idselector.h index 163e58cca..c09aba222 100644 --- a/include/knowhere/bitsetview_idselector.h +++ b/include/knowhere/bitsetview_idselector.h @@ -17,8 +17,8 @@ namespace knowhere { -struct BitsetViewIDSelector : faiss::IDSelector { - BitsetView bitset_view; +struct BitsetViewIDSelector final : faiss::IDSelector { + const BitsetView bitset_view; inline BitsetViewIDSelector(BitsetView bitset_view) : bitset_view{bitset_view} { } diff --git a/src/index/hnsw/impl/BitsetFilter.h b/src/index/hnsw/impl/BitsetFilter.h deleted file mode 100644 index 7b3e5b88b..000000000 --- a/src/index/hnsw/impl/BitsetFilter.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (C) 2019-2024 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#pragma once - -#include - -#include "knowhere/bitsetview.h" - -namespace knowhere { - -// specialized override for knowhere -struct BitsetFilter { - // contains disabled nodes. - knowhere::BitsetView bitset_view; - - inline BitsetFilter(knowhere::BitsetView bitset_view_) : bitset_view{bitset_view_} { - } - - inline bool - allowed(const faiss::idx_t idx) const { - // there's no check for bitset_view.empty() by design - return !bitset_view.test(idx); - } -}; - -} // namespace knowhere diff --git a/src/index/hnsw/impl/IndexBruteForceWrapper.cc b/src/index/hnsw/impl/IndexBruteForceWrapper.cc index ffd3f74cf..937c357e9 100644 --- a/src/index/hnsw/impl/IndexBruteForceWrapper.cc +++ b/src/index/hnsw/impl/IndexBruteForceWrapper.cc @@ -21,7 +21,6 @@ #include #include -#include "index/hnsw/impl/BitsetFilter.h" #include "knowhere/bitsetview.h" #include "knowhere/bitsetview_idselector.h" @@ -29,6 +28,21 @@ namespace knowhere { using idx_t = faiss::idx_t; +// the following structure is a hack, because GCC cannot properly +// de-virtualize a plain BitsetViewIDSelector. +struct BitsetViewIDSelectorWrapper final { + const BitsetView bitset_view; + + inline BitsetViewIDSelectorWrapper(BitsetView bitset_view) : bitset_view{bitset_view} { + } + + inline bool + is_member(faiss::idx_t id) const { + // it is by design that bitset_view.empty() is not tested here + return (!bitset_view.test(id)); + } +}; + // IndexBruteForceWrapper::IndexBruteForceWrapper(faiss::Index* underlying_index) : faiss::cppcontrib::knowhere::IndexWrapper{underlying_index} { @@ -36,7 +50,8 @@ IndexBruteForceWrapper::IndexBruteForceWrapper(faiss::Index* underlying_index) void IndexBruteForceWrapper::search(faiss::idx_t n, const float* __restrict x, faiss::idx_t k, float* __restrict distances, - faiss::idx_t* __restrict labels, const faiss::SearchParameters* params) const { + faiss::idx_t* __restrict labels, + const faiss::SearchParameters* __restrict params) const { FAISS_THROW_IF_NOT(k > 0); std::unique_ptr dis(index->get_distance_computer()); @@ -47,8 +62,8 @@ IndexBruteForceWrapper::search(faiss::idx_t n, const float* __restrict x, faiss: dis->set_query(x + i * index->d); // allocate heap - idx_t* const local_ids = labels + i * index->d; - float* const local_distances = distances + i * index->d; + idx_t* const __restrict local_ids = labels + i * index->d; + float* const __restrict local_distances = distances + i * index->d; // set up a filter faiss::IDSelector* sel = (params == nullptr) ? nullptr : params->sel; @@ -59,20 +74,35 @@ IndexBruteForceWrapper::search(faiss::idx_t n, const float* __restrict x, faiss: } // try knowhere-specific filter - const knowhere::BitsetViewIDSelector* bw_idselector = dynamic_cast(sel); + const knowhere::BitsetViewIDSelector* __restrict bw_idselector = + dynamic_cast(sel); - knowhere::BitsetFilter filter(bw_idselector->bitset_view); + BitsetViewIDSelectorWrapper bw_idselector_w(bw_idselector->bitset_view); if (is_similarity_metric(index->metric_type)) { using C = faiss::CMin; - faiss::cppcontrib::knowhere::brute_force_search_impl( - index->ntotal, *dis, filter, k, local_distances, local_ids); + if (bw_idselector == nullptr || bw_idselector->bitset_view.empty()) { + faiss::IDSelectorAll sel_all; + faiss::cppcontrib::knowhere::brute_force_search_impl( + index->ntotal, *dis, sel_all, k, local_distances, local_ids); + } else { + faiss::cppcontrib::knowhere::brute_force_search_impl( + index->ntotal, *dis, bw_idselector_w, k, local_distances, local_ids); + } } else { using C = faiss::CMax; - faiss::cppcontrib::knowhere::brute_force_search_impl( - index->ntotal, *dis, filter, k, local_distances, local_ids); + if (bw_idselector == nullptr || bw_idselector->bitset_view.empty()) { + faiss::IDSelectorAll sel_all; + faiss::cppcontrib::knowhere::brute_force_search_impl( + index->ntotal, *dis, sel_all, k, local_distances, local_ids); + } else { + faiss::cppcontrib::knowhere::brute_force_search_impl( + index->ntotal, *dis, bw_idselector_w, k, local_distances, local_ids); + } } } } diff --git a/src/index/hnsw/impl/IndexHNSWWrapper.cc b/src/index/hnsw/impl/IndexHNSWWrapper.cc index b8d22a31a..bfe101099 100644 --- a/src/index/hnsw/impl/IndexHNSWWrapper.cc +++ b/src/index/hnsw/impl/IndexHNSWWrapper.cc @@ -14,7 +14,6 @@ #include #include #include -#include #include #include #include @@ -27,7 +26,6 @@ #include #include -#include "index/hnsw/impl/BitsetFilter.h" #include "index/hnsw/impl/FederVisitor.h" #include "knowhere/bitsetview.h" #include "knowhere/bitsetview_idselector.h" @@ -83,7 +81,7 @@ IndexHNSWWrapper::IndexHNSWWrapper(faiss::IndexHNSW* underlying_index) void IndexHNSWWrapper::search(idx_t n, const float* __restrict x, idx_t k, float* __restrict distances, - idx_t* __restrict labels, const faiss::SearchParameters* params_in) const { + idx_t* __restrict labels, const faiss::SearchParameters* __restrict params_in) const { FAISS_THROW_IF_NOT(k > 0); const faiss::IndexHNSW* index_hnsw = dynamic_cast(index); @@ -150,61 +148,65 @@ IndexHNSWWrapper::search(idx_t n, const float* __restrict x, idx_t k, float* __r faiss::IDSelector* sel = (params == nullptr) ? nullptr : params->sel; // try knowhere-specific filter - const knowhere::BitsetViewIDSelector* bw_idselector = dynamic_cast(sel); + const knowhere::BitsetViewIDSelector* __restrict bw_idselector = + dynamic_cast(sel); - if (bw_idselector == nullptr) { + if (bw_idselector == nullptr || bw_idselector->bitset_view.empty()) { // no filter - faiss::cppcontrib::knowhere::AllowAllFilter filter; + faiss::IDSelectorAll sel_all; // feder templating is important, bcz it removes an unneeded 'CALL' instruction. if (feder == nullptr) { // no feder DummyVisitor graph_visitor; - using searcher_type = - faiss::cppcontrib::knowhere::v2_hnsw_searcher; + using searcher_type = faiss::cppcontrib::knowhere::v2_hnsw_searcher< + faiss::DistanceComputer, DummyVisitor, faiss::cppcontrib::knowhere::Bitset, faiss::IDSelectorAll>; - searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, filter, kAlpha, params}; + searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, + sel_all, kAlpha, params}; local_stats = searcher.search(k, distances + i * k, labels + i * k); } else { // use feder FederVisitor graph_visitor(feder); - using searcher_type = - faiss::cppcontrib::knowhere::v2_hnsw_searcher; + using searcher_type = faiss::cppcontrib::knowhere::v2_hnsw_searcher< + faiss::DistanceComputer, FederVisitor, faiss::cppcontrib::knowhere::Bitset, faiss::IDSelectorAll>; - searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, filter, kAlpha, params}; + searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, + sel_all, kAlpha, params}; local_stats = searcher.search(k, distances + i * k, labels + i * k); } } else { // with filter - knowhere::BitsetFilter filter(bw_idselector->bitset_view); // feder templating is important, bcz it removes an unneeded 'CALL' instruction. if (feder == nullptr) { // no feder DummyVisitor graph_visitor; - using searcher_type = faiss::cppcontrib::knowhere::v2_hnsw_searcher< - faiss::DistanceComputer, DummyVisitor, faiss::cppcontrib::knowhere::Bitset, knowhere::BitsetFilter>; + using searcher_type = + faiss::cppcontrib::knowhere::v2_hnsw_searcher; - searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, filter, kAlpha, params}; + searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, + *bw_idselector, kAlpha, params}; local_stats = searcher.search(k, distances + i * k, labels + i * k); } else { // use feder FederVisitor graph_visitor(feder); - using searcher_type = faiss::cppcontrib::knowhere::v2_hnsw_searcher< - faiss::DistanceComputer, FederVisitor, faiss::cppcontrib::knowhere::Bitset, knowhere::BitsetFilter>; + using searcher_type = + faiss::cppcontrib::knowhere::v2_hnsw_searcher; - searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, filter, kAlpha, params}; + searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, + *bw_idselector, kAlpha, params}; local_stats = searcher.search(k, distances + i * k, labels + i * k); } diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.cpp b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.cpp index 4c6cad23d..68b6d4d5e 100644 --- a/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.cpp +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.cpp @@ -22,7 +22,6 @@ #include #include -#include namespace faiss { namespace cppcontrib { @@ -37,7 +36,7 @@ void IndexBruteForceWrapper::search( idx_t k, float* __restrict distances, idx_t* __restrict labels, - const SearchParameters* params + const SearchParameters* __restrict params ) const { FAISS_THROW_IF_NOT(k > 0); @@ -57,39 +56,60 @@ void IndexBruteForceWrapper::search( dis->set_query(x + i * index->d); // allocate heap - idx_t* const local_ids = labels + i * index->d; - float* const local_distances = distances + i * index->d; + idx_t* const __restrict local_ids = labels + i * index->d; + float* const __restrict local_distances = distances + i * index->d; // set up a filter - IDSelector* sel = (params == nullptr) ? nullptr : params->sel; - - // template, just in case a filter type will be specialized - // in order to remove virtual function call overhead. - using filter_type = DefaultIDSelectorFilter; - filter_type filter(sel); + IDSelector* __restrict sel = (params == nullptr) ? nullptr : params->sel; if (is_similarity_metric(index->metric_type)) { using C = CMin; - brute_force_search_impl( - index->ntotal, - *dis, - filter, - k, - local_distances, - local_ids - ); + if (sel == nullptr) { + // Compiler is expected to de-virtualize virtual method calls + IDSelectorAll sel_all; + brute_force_search_impl( + index->ntotal, + *dis, + sel_all, + k, + local_distances, + local_ids + ); + } else { + brute_force_search_impl( + index->ntotal, + *dis, + *sel, + k, + local_distances, + local_ids + ); + } } else { using C = CMax; - brute_force_search_impl( - index->ntotal, - *dis, - filter, - k, - local_distances, - local_ids - ); + if (sel == nullptr) { + // Compiler is expected to de-virtualize virtual method calls + IDSelectorAll sel_all; + brute_force_search_impl( + index->ntotal, + *dis, + sel_all, + k, + local_distances, + local_ids + ); + } else { + brute_force_search_impl( + index->ntotal, + *dis, + *sel, + k, + local_distances, + local_ids + ); + } } } } diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.cpp b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.cpp index 6d56383e7..dcc20a4c8 100644 --- a/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.cpp +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.cpp @@ -21,7 +21,6 @@ #include #include -#include #include #include @@ -156,29 +155,49 @@ void IndexHNSWWrapper::search( // set up a filter IDSelector* sel = (params == nullptr) ? nullptr : params->sel; - - // template, just in case a filter type will be specialized - // in order to remove virtual function call overhead. - using filter_type = DefaultIDSelectorFilter; - filter_type filter(sel); - - using searcher_type = v2_hnsw_searcher< - DistanceComputer, - DummyVisitor, - Bitset, - filter_type>; - - searcher_type searcher{ - hnsw, - *(dis.get()), - graph_visitor, - bitset_visited_nodes, - filter, - kAlpha, - params - }; - - local_stats = searcher.search(k, distances + i * k, labels + i * k); + if (sel == nullptr) { + // no filter. + // It it expected that a compile will be able to + // de-virtualize the class. + IDSelectorAll sel_all; + + using searcher_type = v2_hnsw_searcher< + DistanceComputer, + DummyVisitor, + Bitset, + IDSelectorAll>; + + searcher_type searcher{ + hnsw, + *(dis.get()), + graph_visitor, + bitset_visited_nodes, + sel_all, + kAlpha, + params + }; + + local_stats = searcher.search(k, distances + i * k, labels + i * k); + } else { + // there is a filter + using searcher_type = v2_hnsw_searcher< + DistanceComputer, + DummyVisitor, + Bitset, + IDSelector>; + + searcher_type searcher{ + hnsw, + *(dis.get()), + graph_visitor, + bitset_visited_nodes, + *sel, + kAlpha, + params + }; + + local_stats = searcher.search(k, distances + i * k, labels + i * k); + } // update stats if possible if (hnsw_stats != nullptr) { diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Bruteforce.h b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Bruteforce.h index f13a35740..10073811a 100644 --- a/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Bruteforce.h +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Bruteforce.h @@ -28,8 +28,8 @@ namespace knowhere { template void brute_force_search_impl( const idx_t ntotal, - DistanceComputerT& qdis, - const FilterT filter, + DistanceComputerT& __restrict qdis, + const FilterT& __restrict filter, const idx_t k, float* __restrict distances, idx_t* __restrict labels @@ -40,7 +40,7 @@ void brute_force_search_impl( auto max_heap = std::make_unique[]>(k); idx_t n_added = 0; for (idx_t idx = 0; idx < ntotal; ++idx) { - if (filter.allowed(idx)) { + if (filter.is_member(idx)) { const float distance = qdis(idx); if (n_added < k) { n_added += 1; diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Filters.h b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Filters.h deleted file mode 100644 index fce3a673b..000000000 --- a/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Filters.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (C) 2019-2024 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#pragma once - -#include - -namespace faiss { -namespace cppcontrib { -namespace knowhere { - -// a filter that selects according to an IDSelector -template -struct DefaultIDSelectorFilter { - // contains enabled nodes. - // a non-owning pointer. - const IDSelectorT* selector = nullptr; - - inline DefaultIDSelectorFilter(const IDSelectorT* selector_) : selector{selector_} {} - - inline bool allowed(const idx_t idx) const { - return ((selector == nullptr) || (selector->is_member(idx))); - } -}; - -// a filter that allows everything -struct AllowAllFilter { - constexpr inline bool allowed(const idx_t) const { - return true; - } -}; - - -} -} -} diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/HnswSearcher.h b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/HnswSearcher.h index 29501bfab..3bc624afe 100644 --- a/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/HnswSearcher.h +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/HnswSearcher.h @@ -72,7 +72,7 @@ struct v2_hnsw_searcher { // a filter for disabled nodes. // the reference is not owned. - const FilterT filter; + const FilterT& filter; // parameter for the filtering const float kAlpha; @@ -205,7 +205,7 @@ struct v2_hnsw_searcher { // is the node disabled? int status = knowhere::Neighbor::kValid; - if (!filter.allowed(v1)) { + if (!filter.is_member(v1)) { // yes, disabled status = knowhere::Neighbor::kInvalid; @@ -381,7 +381,7 @@ struct v2_hnsw_searcher { // initialize retset with a single 'nearest' point { - if (!filter.allowed(nearest)) { + if (!filter.is_member(nearest)) { retset.insert(knowhere::Neighbor(nearest, d_nearest, knowhere::Neighbor::kInvalid)); } else { retset.insert(knowhere::Neighbor(nearest, d_nearest, knowhere::Neighbor::kValid)); diff --git a/thirdparty/faiss/faiss/impl/IDSelector.h b/thirdparty/faiss/faiss/impl/IDSelector.h index 4d02540a2..2966881ea 100644 --- a/thirdparty/faiss/faiss/impl/IDSelector.h +++ b/thirdparty/faiss/faiss/impl/IDSelector.h @@ -125,7 +125,7 @@ struct IDSelectorNot : IDSelector { /// selects all entries (useful for benchmarking) struct IDSelectorAll : IDSelector { - bool is_member(idx_t id) const final { + inline bool is_member(idx_t id) const final { return true; } virtual ~IDSelectorAll() {} diff --git a/thirdparty/faiss/faiss/utils/distances.cpp b/thirdparty/faiss/faiss/utils/distances.cpp index 6f4186443..1584d72e1 100644 --- a/thirdparty/faiss/faiss/utils/distances.cpp +++ b/thirdparty/faiss/faiss/utils/distances.cpp @@ -135,7 +135,7 @@ namespace { // may be useful if the lion's share of samples are filtered out. struct IDSelectorAll { - inline bool is_member(const size_t idx) const { + constexpr inline bool is_member(const size_t idx) const { return true; } };