From 4e5cba3a087f8983091d4467edb26278b6ef0cbf Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 30 Jul 2024 15:08:51 -0400 Subject: [PATCH 01/21] basic FAISS_HNSW_FLAT Signed-off-by: Alexandr Guzhva --- cmake/libs/libfaiss.cmake | 3 +- include/knowhere/comp/index_param.h | 2 + include/knowhere/index/index_table.h | 2 + src/index/hnsw/faiss_hnsw.cc | 630 ++++++++++++++++++ src/index/hnsw/faiss_hnsw_config.h | 136 ++++ src/index/hnsw/impl/BitsetFilter.h | 35 + src/index/hnsw/impl/FederVisitor.h | 49 ++ src/index/hnsw/impl/IndexBruteForceWrapper.cc | 80 +++ src/index/hnsw/impl/IndexBruteForceWrapper.h | 29 + src/index/hnsw/impl/IndexHNSWWrapper.cc | 241 +++++++ src/index/hnsw/impl/IndexHNSWWrapper.h | 51 ++ src/index/hnsw/impl/IndexWrapperCosine.cc | 29 + src/index/hnsw/impl/IndexWrapperCosine.h | 32 + .../faiss/benchs/bench_hnsw_knowhere.cpp | 206 ++++++ thirdparty/faiss/faiss/IndexCosine.cpp | 319 +++++++++ thirdparty/faiss/faiss/IndexCosine.h | 110 +++ .../knowhere/IndexBruteForceWrapper.cpp | 103 +++ .../knowhere/IndexBruteForceWrapper.h | 39 ++ .../cppcontrib/knowhere/IndexHNSWWrapper.cpp | 212 ++++++ .../cppcontrib/knowhere/IndexHNSWWrapper.h | 56 ++ .../cppcontrib/knowhere/IndexWrapper.cpp | 65 ++ .../faiss/cppcontrib/knowhere/IndexWrapper.h | 51 ++ .../cppcontrib/knowhere/impl/Bruteforce.h | 73 ++ .../faiss/cppcontrib/knowhere/impl/Filters.h | 44 ++ .../cppcontrib/knowhere/impl/HnswSearcher.h | 415 ++++++++++++ .../faiss/cppcontrib/knowhere/impl/Neighbor.h | 229 +++++++ .../faiss/cppcontrib/knowhere/utils/Bitset.h | 115 ++++ thirdparty/faiss/faiss/impl/index_read.cpp | 19 +- thirdparty/faiss/faiss/impl/index_write.cpp | 10 + 29 files changed, 3383 insertions(+), 2 deletions(-) create mode 100644 src/index/hnsw/faiss_hnsw.cc create mode 100644 src/index/hnsw/faiss_hnsw_config.h create mode 100644 src/index/hnsw/impl/BitsetFilter.h create mode 100644 src/index/hnsw/impl/FederVisitor.h create mode 100644 src/index/hnsw/impl/IndexBruteForceWrapper.cc create mode 100644 src/index/hnsw/impl/IndexBruteForceWrapper.h create mode 100644 src/index/hnsw/impl/IndexHNSWWrapper.cc create mode 100644 src/index/hnsw/impl/IndexHNSWWrapper.h create mode 100644 src/index/hnsw/impl/IndexWrapperCosine.cc create mode 100644 src/index/hnsw/impl/IndexWrapperCosine.h create mode 100644 thirdparty/faiss/benchs/bench_hnsw_knowhere.cpp create mode 100644 thirdparty/faiss/faiss/IndexCosine.cpp create mode 100644 thirdparty/faiss/faiss/IndexCosine.h create mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.cpp create mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.h create mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.cpp create mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.h create mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/IndexWrapper.cpp create mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/IndexWrapper.h create mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Bruteforce.h create mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Filters.h create mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/impl/HnswSearcher.h create mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Neighbor.h create mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/utils/Bitset.h diff --git a/cmake/libs/libfaiss.cmake b/cmake/libs/libfaiss.cmake index 7200eb177..f4022bb4a 100644 --- a/cmake/libs/libfaiss.cmake +++ b/cmake/libs/libfaiss.cmake @@ -1,7 +1,8 @@ 
knowhere_file_glob( GLOB FAISS_SRCS thirdparty/faiss/faiss/*.cpp thirdparty/faiss/faiss/impl/*.cpp thirdparty/faiss/faiss/invlists/*.cpp - thirdparty/faiss/faiss/utils/*.cpp) + thirdparty/faiss/faiss/utils/*.cpp + thirdparty/faiss/faiss/cppcontrib/knowhere/*.cpp) knowhere_file_glob(GLOB FAISS_AVX512_SRCS thirdparty/faiss/faiss/impl/*avx512.cpp) diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index 504aeeef0..612a6feb9 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -53,6 +53,8 @@ constexpr const char* INDEX_HNSW_SQ8 = "HNSW_SQ8"; constexpr const char* INDEX_HNSW_SQ8_REFINE = "HNSW_SQ8_REFINE"; constexpr const char* INDEX_DISKANN = "DISKANN"; +constexpr const char* INDEX_FAISS_HNSW_FLAT = "FAISS_HNSW_FLAT"; + constexpr const char* INDEX_SPARSE_INVERTED_INDEX = "SPARSE_INVERTED_INDEX"; constexpr const char* INDEX_SPARSE_WAND = "SPARSE_WAND"; } // namespace IndexEnum diff --git a/include/knowhere/index/index_table.h b/include/knowhere/index/index_table.h index 34a646b63..3c29973ba 100644 --- a/include/knowhere/index/index_table.h +++ b/include/knowhere/index/index_table.h @@ -66,6 +66,8 @@ static std::set> legal_knowhere_index = { {IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_FLOAT}, {IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_FLOAT16}, {IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_BFLOAT16}, + // faiss hnsw + {IndexEnum::INDEX_FAISS_HNSW_FLAT, VecType::VECTOR_FLOAT}, // diskann {IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT}, {IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT16}, diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc new file mode 100644 index 000000000..c42269ff7 --- /dev/null +++ b/src/index/hnsw/faiss_hnsw.cc @@ -0,0 +1,630 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
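+
+// Illustrative usage sketch: driving the FAISS_HNSW_FLAT index registered below
+// through the knowhere factory. The JSON keys mirror the fields declared in
+// faiss_hnsw_config.h; `train_dataset`, `query_dataset` and `version` are assumed to
+// be supplied by the caller, and the factory/Build/Search calls follow the public
+// knowhere API (exact signatures may differ slightly between releases).
+//
+//   auto index = knowhere::IndexFactory::Instance()
+//                    .Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS_HNSW_FLAT, version)
+//                    .value();
+//   knowhere::Json json;
+//   json["metric_type"] = "L2";
+//   json["M"] = 30;
+//   json["efConstruction"] = 360;
+//   index.Build(train_dataset, json);                    // Train() + Add()
+//   json["k"] = 10;
+//   json["ef"] = 64;                                      // must be >= k, see CheckAndAdjust()
+//   auto results = index.Search(query_dataset, json, knowhere::BitsetView{});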
+ +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "common/metric.h" +#include "faiss/IndexBinaryHNSW.h" +#include "faiss/IndexCosine.h" +#include "faiss/IndexHNSW.h" +#include "faiss/IndexRefine.h" +#include "faiss/impl/ScalarQuantizer.h" +#include "faiss/index_io.h" +#include "index/hnsw/faiss_hnsw_config.h" +#include "index/hnsw/impl/IndexBruteForceWrapper.h" +#include "index/hnsw/impl/IndexHNSWWrapper.h" +#include "index/hnsw/impl/IndexWrapperCosine.h" +#include "io/memory_io.h" +#include "knowhere/bitsetview_idselector.h" +#include "knowhere/comp/index_param.h" +#include "knowhere/comp/thread_pool.h" +#include "knowhere/comp/time_recorder.h" +#include "knowhere/config.h" +#include "knowhere/expected.h" +#include "knowhere/index/index_factory.h" +#include "knowhere/index/index_node_data_mock_wrapper.h" +#include "knowhere/log.h" +#include "knowhere/range_util.h" +#include "knowhere/utils.h" + +#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) +#include "knowhere/prometheus_client.h" +#endif + +namespace knowhere { + +// +class BaseFaissIndexNode : public IndexNode { + public: + BaseFaissIndexNode(const int32_t& /*version*/, const Object& object) { + build_pool = ThreadPool::GetGlobalBuildThreadPool(); + search_pool = ThreadPool::GetGlobalSearchThreadPool(); + } + + // + Status + Train(const DataSetPtr dataset, const Config& cfg) override { + // config + const BaseConfig& base_cfg = static_cast(cfg); + + // use build_pool_ to make sure the OMP threads spawned by index_->train etc + // can inherit the low nice value of threads in build_pool_. + auto tryObj = build_pool + ->push([&] { + std::unique_ptr setter; + if (base_cfg.num_build_thread.has_value()) { + setter = + std::make_unique(base_cfg.num_build_thread.value()); + } else { + setter = std::make_unique(); + } + + return TrainInternal(dataset, cfg); + }) + .getTry(); + + if (!tryObj.hasValue()) { + LOG_KNOWHERE_WARNING_ << "faiss internal error: " << tryObj.exception().what(); + return Status::faiss_inner_error; + } + + return tryObj.value(); + } + + Status + Add(const DataSetPtr dataset, const Config& cfg) override { + const BaseConfig& base_cfg = static_cast(cfg); + + // use build_pool_ to make sure the OMP threads spawned by index_->train etc + // can inherit the low nice value of threads in build_pool_. 
+ auto tryObj = build_pool + ->push([&] { + std::unique_ptr setter; + if (base_cfg.num_build_thread.has_value()) { + setter = + std::make_unique(base_cfg.num_build_thread.value()); + } else { + setter = std::make_unique(); + } + + return AddInternal(dataset, cfg); + }) + .getTry(); + + if (!tryObj.hasValue()) { + LOG_KNOWHERE_WARNING_ << "faiss internal error: " << tryObj.exception().what(); + return Status::faiss_inner_error; + } + + return tryObj.value(); + } + + int64_t + Size() const override { + // todo + return 0; + } + + expected + GetIndexMeta(const Config& cfg) const override { + // todo + return expected::Err(Status::not_implemented, "GetIndexMeta not implemented"); + } + + protected: + std::shared_ptr build_pool; + std::shared_ptr search_pool; + + // train impl + virtual Status + TrainInternal(const DataSetPtr dataset, const Config& cfg) = 0; + + // add impl + virtual Status + AddInternal(const DataSetPtr dataset, const Config& cfg) = 0; +}; + +// +class BaseFaissRegularIndexNode : public BaseFaissIndexNode { + public: + BaseFaissRegularIndexNode(const int32_t& version, const Object& object) + : BaseFaissIndexNode(version, object), index{nullptr} { + } + + expected + GetVectorByIds(const DataSetPtr dataset) const override { + if (this->index == nullptr) { + return expected::Err(Status::empty_index, "index not loaded"); + } + if (!this->index->is_trained) { + return expected::Err(Status::index_not_trained, "index not trained"); + } + + auto dim = Dim(); + auto rows = dataset->GetRows(); + auto ids = dataset->GetIds(); + + try { + auto data = std::make_unique(dim * rows); + + for (int64_t i = 0; i < rows; i++) { + const int64_t id = ids[i]; + assert(id >= 0 && id < index->ntotal); + index->reconstruct(id, data.get() + i * dim); + } + + return GenResultDataSet(rows, dim, std::move(data)); + } catch (const std::exception& e) { + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); + return expected::Err(Status::faiss_inner_error, e.what()); + } + } + + Status + Serialize(BinarySet& binset) const override { + if (index == nullptr) { + return Status::empty_index; + } + + try { + MemoryIOWriter writer; + faiss::write_index(index.get(), &writer); + + std::shared_ptr data(writer.data()); + binset.Append(Type(), data, writer.tellg()); + } catch (const std::exception& e) { + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); + return Status::faiss_inner_error; + } + + return Status::success; + } + + Status + Deserialize(const BinarySet& binset, const Config& config) override { + auto binary = binset.GetByName(Type()); + if (binary == nullptr) { + LOG_KNOWHERE_ERROR_ << "Invalid binary set."; + return Status::invalid_binary_set; + } + + MemoryIOReader reader(binary->data.get(), binary->size); + try { + auto read_index = std::unique_ptr(faiss::read_index(&reader)); + index.reset(read_index.release()); + } catch (const std::exception& e) { + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); + return Status::faiss_inner_error; + } + + return Status::success; + } + + Status + DeserializeFromFile(const std::string& filename, const Config& config) override { + auto cfg = static_cast(config); + + int io_flags = 0; + if (cfg.enable_mmap.value()) { + io_flags |= faiss::IO_FLAG_MMAP; + } + + try { + auto read_index = std::unique_ptr(faiss::read_index(filename.data(), io_flags)); + index.reset(read_index.release()); + } catch (const std::exception& e) { + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); + return Status::faiss_inner_error; + } + + return 
Status::success; + } + + // + int64_t + Dim() const override { + if (index == nullptr) { + return -1; + } + + return index->d; + } + + int64_t + Count() const override { + if (index == nullptr) { + return -1; + } + + // total number of indexed vectors + return index->ntotal; + } + + protected: + std::unique_ptr index; + + Status + AddInternal(const DataSetPtr dataset, const Config&) override { + if (this->index == nullptr) { + LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; + return Status::empty_index; + } + + auto data = dataset->GetTensor(); + auto rows = dataset->GetRows(); + try { + this->index->add(rows, reinterpret_cast(data)); + } catch (const std::exception& e) { + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); + return Status::faiss_inner_error; + } + + return Status::success; + } +}; + +// +class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { + public: + BaseFaissRegularIndexHNSWNode(const int32_t& version, const Object& object) + : BaseFaissRegularIndexNode(version, object) { + } + + bool + HasRawData(const std::string& metric_type) const override { + if (index == nullptr) { + return false; + } + + // check whether we use a refined index + const faiss::IndexRefine* const index_refine = dynamic_cast(index.get()); + if (index_refine == nullptr) { + return false; + } + + // check whether the refine index is IndexFlat + // todo: SQfp16 is good for fp16 data type + // todo: SQbf16 is good for bf16 data type + const faiss::IndexFlat* const index_refine_flat = + dynamic_cast(index_refine->refine_index); + if (index_refine_flat == nullptr) { + // we might be using a different refine index + return false; + } + + // yes, we're using IndexRefine with a Flat index + return true; + } + + expected + Search(const DataSetPtr dataset, const Config& cfg, const BitsetView& bitset) const override { + if (this->index == nullptr) { + return expected::Err(Status::empty_index, "index not loaded"); + } + if (!this->index->is_trained) { + return expected::Err(Status::index_not_trained, "index not trained"); + } + + const auto dim = dataset->GetDim(); + const auto rows = dataset->GetRows(); + const auto* data = dataset->GetTensor(); + + const auto hnsw_cfg = static_cast(cfg); + const auto k = hnsw_cfg.k.value(); + const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), knowhere::metric::COSINE); + + const bool whether_bf_search = WhetherPerformBruteForceSearch(hnsw_cfg, bitset); + + feder::hnsw::FederResultUniq feder_result; + if (hnsw_cfg.trace_visit.value()) { + if (rows != 1) { + return expected::Err(Status::invalid_args, "a single query vector is required"); + } + feder_result = std::make_unique(); + } + + auto ids = std::make_unique(rows * k); + auto distances = std::make_unique(rows * k); + try { + std::vector> futs; + futs.reserve(rows); + for (int64_t i = 0; i < rows; ++i) { + futs.emplace_back(search_pool->push([&, idx = i] { + // 1 thread per element + ThreadPool::ScopedOmpSetter setter(1); + + // set up a query + const float* cur_query = (const float*)data + idx * dim; + + // set up local results + faiss::idx_t* const __restrict local_ids = ids.get() + k * idx; + float* const __restrict local_distances = distances.get() + k * idx; + + // set up faiss search parameters + knowhere::SearchParametersHNSWWrapper hnsw_search_params; + if (hnsw_cfg.ef.has_value()) { + hnsw_search_params.efSearch = hnsw_cfg.ef.value(); + } + // do not collect HNSW stats + hnsw_search_params.hnsw_stats = nullptr; + // set up feder + hnsw_search_params.feder = 
feder_result.get(); + // set up kAlpha + hnsw_search_params.kAlpha = bitset.filter_ratio() * 0.7f; + + // set up a selector + BitsetViewIDSelector bw_idselector(bitset); + faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; + + hnsw_search_params.sel = id_selector; + + // use knowhere-based search by default + const bool override_faiss_search = hnsw_cfg.override_faiss_search.value_or(true); + + // check if we have a refine available. + faiss::IndexRefine* const index_refine = dynamic_cast(index.get()); + + if (index_refine != nullptr) { + // yes, it is possible to refine results. + + // cast a base index to IndexHNSW-based index + faiss::IndexHNSW* const index_hnsw = dynamic_cast(index_refine->base_index); + + if (index_hnsw == nullptr) { + // this is unexpected + throw std::runtime_error("Expecting faiss::IndexHNSW"); + } + + // pick a wrapper for hnsw which does not own indices + knowhere::IndexHNSWWrapper wrapper_hnsw_search(index_hnsw); + knowhere::IndexBruteForceWrapper wrapper_bf(index_hnsw); + + faiss::Index* base_wrapper = nullptr; + if (!override_faiss_search) { + // use the original index, no wrappers + base_wrapper = index_hnsw; + } else if (whether_bf_search) { + // use brute-force wrapper + base_wrapper = &wrapper_bf; + } else { + // use hnsw-search wrapper + base_wrapper = &wrapper_hnsw_search; + } + + // check if used wants a refined result + if (hnsw_cfg.refine_k.has_value()) { + // yes, a user wants to perform a refine + + // set up search parameters + faiss::IndexRefineSearchParameters refine_params; + refine_params.k_factor = hnsw_cfg.refine_k.value(); + // a refine procedure itself does not need to care about filtering + refine_params.sel = nullptr; + refine_params.base_index_params = &hnsw_search_params; + + // is it a cosine index? + if (index_hnsw->storage->is_cosine && is_cosine) { + // yes, wrap both base and refine index + knowhere::IndexWrapperCosine cosine_wrapper( + index_refine->refine_index, + dynamic_cast(index_hnsw)->get_inverse_l2_norms()); + + // create a temporary refine index which does not own + faiss::IndexRefine tmp_refine(base_wrapper, &cosine_wrapper); + + // perform a search + tmp_refine.search(1, cur_query, k, local_distances, local_ids, &refine_params); + } else { + // no, wrap base index only. 
+ + // create a temporary refine index which does not own + faiss::IndexRefine tmp_refine(base_wrapper, index_refine->refine_index); + + // perform a search + tmp_refine.search(1, cur_query, k, local_distances, local_ids, &refine_params); + } + } else { + // no, a user wants to skip a refine + + // perform a search + base_wrapper->search(1, cur_query, k, local_distances, local_ids, &hnsw_search_params); + } + } else { + // there's no refining available + + // check if refine is required + if (hnsw_cfg.refine_k.has_value()) { + // this is not possible, throw an error + throw std::runtime_error("Refine is not provided by the index."); + } + + // cast to IndexHNSW-based index + faiss::IndexHNSW* const index_hnsw = dynamic_cast(index.get()); + + if (index_hnsw == nullptr) { + // this is unexpected + throw std::runtime_error("Expecting faiss::IndexHNSW"); + } + + // pick a wrapper for hnsw which does not own indices + knowhere::IndexHNSWWrapper wrapper_hnsw_search(index_hnsw); + knowhere::IndexBruteForceWrapper wrapper_bf(index_hnsw); + + faiss::Index* wrapper = nullptr; + if (!override_faiss_search) { + // use the original index, no wrappers + wrapper = index_hnsw; + } else if (whether_bf_search) { + // use brute-force wrapper + wrapper = &wrapper_bf; + } else { + // use hnsw-search wrapper + wrapper = &wrapper_hnsw_search; + } + + // perform a search + wrapper->search(1, cur_query, k, local_distances, local_ids, &hnsw_search_params); + } + })); + } + + // wait for the completion + WaitAllSuccess(futs); + } catch (const std::exception& e) { + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); + return expected::Err(Status::faiss_inner_error, e.what()); + } + + auto res = GenResultDataSet(rows, k, std::move(ids), std::move(distances)); + + // set visit_info json string into result dataset + if (feder_result != nullptr) { + Json json_visit_info, json_id_set; + nlohmann::to_json(json_visit_info, feder_result->visit_info_); + nlohmann::to_json(json_id_set, feder_result->id_set_); + res->SetJsonInfo(json_visit_info.dump()); + res->SetJsonIdSet(json_id_set.dump()); + } + + return res; + } + + protected: + // Decides whether a brute force should be used instead of a regular HNSW search. + // This may be applicable in case of very large topk values or + // extremely high filtering levels. 
+ bool + WhetherPerformBruteForceSearch(const BaseConfig& cfg, const BitsetView& bitset) const { + constexpr float kHnswSearchKnnBFFilterThreshold = 0.93f; + constexpr float kHnswSearchRangeBFFilterThreshold = 0.97f; + constexpr float kHnswSearchBFTopkThreshold = 0.5f; + + auto k = cfg.k.value(); + + if (k >= (index->ntotal * kHnswSearchBFTopkThreshold)) { + return true; + } + + if (!bitset.empty()) { + const size_t filtered_out_num = bitset.count(); +#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) + double ratio = ((double)filtered_out_num) / bitset.size(); + knowhere::knowhere_hnsw_bitset_ratio.Observe(ratio); +#endif + if (filtered_out_num >= (index->ntotal * kHnswSearchKnnBFFilterThreshold) || + k >= (index->ntotal - filtered_out_num) * kHnswSearchBFTopkThreshold) { + return true; + } + } + + // the default value + return false; + } + + Status + AddInternal(const DataSetPtr dataset, const Config&) override { + if (this->index == nullptr) { + LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; + return Status::empty_index; + } + + auto data = dataset->GetTensor(); + auto rows = dataset->GetRows(); + try { + this->index->add(rows, reinterpret_cast(data)); + } catch (const std::exception& e) { + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); + return Status::faiss_inner_error; + } + + return Status::success; + } +}; + +// +template +class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode { + public: + BaseFaissRegularIndexHNSWFlatNode(const int32_t& version, const Object& object) + : BaseFaissRegularIndexHNSWNode(version, object) { + } + + bool + HasRawData(const std::string& metric_type) const override { + if (index == nullptr) { + return false; + } + + // yes, a flat index has it + return true; + } + + std::unique_ptr + CreateConfig() const override { + return std::make_unique(); + } + + std::string + Type() const override { + return knowhere::IndexEnum::INDEX_FAISS_HNSW_FLAT; + } + + protected: + Status + TrainInternal(const DataSetPtr dataset, const Config& cfg) override { + // number of rows + auto rows = dataset->GetRows(); + // dimensionality of the data + auto dim = dataset->GetDim(); + // data + auto data = dataset->GetTensor(); + + // config + auto hnsw_cfg = static_cast(cfg); + + auto metric = Str2FaissMetricType(hnsw_cfg.metric_type.value()); + if (!metric.has_value()) { + LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << hnsw_cfg.metric_type.value(); + return Status::invalid_metric_type; + } + + // create an index + const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE); + + std::unique_ptr hnsw_index; + if (is_cosine) { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); + } else { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); + } + + hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); + + // train + hnsw_index->train(rows, (const float*)data); + + // done + index = std::move(hnsw_index); + return Status::success; + } +}; + +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNode, fp32); + +} // namespace knowhere diff --git a/src/index/hnsw/faiss_hnsw_config.h b/src/index/hnsw/faiss_hnsw_config.h new file mode 100644 index 000000000..f36450460 --- /dev/null +++ b/src/index/hnsw/faiss_hnsw_config.h @@ -0,0 +1,136 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#ifndef FAISS_HNSW_CONFIG_H +#define FAISS_HNSW_CONFIG_H + +#include "knowhere/comp/index_param.h" +#include "knowhere/config.h" + +namespace knowhere { + +namespace { + +constexpr const CFG_INT::value_type kIteratorSeedEf = 40; +constexpr const CFG_INT::value_type kEfMinValue = 16; +constexpr const CFG_INT::value_type kDefaultRangeSearchEf = 16; + +} // namespace + +class FaissHnswConfig : public BaseConfig { + public: + CFG_INT M; + CFG_INT efConstruction; + CFG_INT ef; + CFG_INT seed_ef; + CFG_INT overview_levels; + + // use a knowhere search rather than a default faiss search + CFG_BOOL override_faiss_search; + // whether an index is built with a refine support + CFG_BOOL refine; + // undefined value leads to a search without a refine + CFG_FLOAT refine_k; + + KNOHWERE_DECLARE_CONFIG(FaissHnswConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(M).description("hnsw M").set_default(30).set_range(2, 2048).for_train(); + KNOWHERE_CONFIG_DECLARE_FIELD(efConstruction) + .description("hnsw efConstruction") + .set_default(360) + .set_range(1, std::numeric_limits::max()) + .for_train(); + KNOWHERE_CONFIG_DECLARE_FIELD(ef) + .description("hnsw ef") + .allow_empty_without_default() + .set_range(1, std::numeric_limits::max()) + .for_search() + .for_range_search(); + KNOWHERE_CONFIG_DECLARE_FIELD(seed_ef) + .description("hnsw seed_ef when using iterator") + .set_default(kIteratorSeedEf) + .set_range(1, std::numeric_limits::max()) + .for_iterator(); + KNOWHERE_CONFIG_DECLARE_FIELD(overview_levels) + .description("hnsw overview levels for feder") + .set_default(3) + .set_range(1, 5) + .for_feder(); + // this is a mostly debugging field + // todo: remove at later stages + KNOWHERE_CONFIG_DECLARE_FIELD(override_faiss_search) + .description("use knowhere-based search rather than faiss-based search") + .set_default(true) + .for_search(); + KNOWHERE_CONFIG_DECLARE_FIELD(refine) + .description("whether the refine is used during the train") + .set_default(false) + .for_train(); + KNOWHERE_CONFIG_DECLARE_FIELD(refine_k) + .description("refine k") + .allow_empty_without_default() + .set_range(1, std::numeric_limits::max()) + .for_search(); + } + + Status + CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { + switch (param_type) { + case PARAM_TYPE::SEARCH: { + // validate ef + if (!ef.has_value()) { + ef = std::max(k.value(), kEfMinValue); + } else if (k.value() > ef.value()) { + *err_msg = "ef(" + std::to_string(ef.value()) + ") should be larger than k(" + + std::to_string(k.value()) + ")"; + LOG_KNOWHERE_ERROR_ << *err_msg; + return Status::out_of_range_in_json; + } + break; + } + case PARAM_TYPE::RANGE_SEARCH: { + if (!ef.has_value()) { + // if ef is not set by user, set it to default + ef = kDefaultRangeSearchEf; + } + break; + } + default: + break; + } + return Status::success; + } +}; + +class FaissHnswFlatConfig : public FaissHnswConfig { + public: + Status + CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { + // check the base class + const auto base_status = FaissHnswConfig::CheckAndAdjust(param_type, err_msg); + if (base_status != Status::success) { + return 
base_status; + } + + // check our parameters + if (param_type == PARAM_TYPE::TRAIN) { + if (refine.value_or(false)) { + *err_msg = "refine is not currently supported for this index"; + LOG_KNOWHERE_ERROR_ << *err_msg; + return Status::invalid_value_in_json; + } + } + return Status::success; + } +}; + +} // namespace knowhere + +#endif /* FAISS_HNSW_CONFIG_H */ diff --git a/src/index/hnsw/impl/BitsetFilter.h b/src/index/hnsw/impl/BitsetFilter.h new file mode 100644 index 000000000..7b3e5b88b --- /dev/null +++ b/src/index/hnsw/impl/BitsetFilter.h @@ -0,0 +1,35 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include + +#include "knowhere/bitsetview.h" + +namespace knowhere { + +// specialized override for knowhere +struct BitsetFilter { + // contains disabled nodes. + knowhere::BitsetView bitset_view; + + inline BitsetFilter(knowhere::BitsetView bitset_view_) : bitset_view{bitset_view_} { + } + + inline bool + allowed(const faiss::idx_t idx) const { + // there's no check for bitset_view.empty() by design + return !bitset_view.test(idx); + } +}; + +} // namespace knowhere diff --git a/src/index/hnsw/impl/FederVisitor.h b/src/index/hnsw/impl/FederVisitor.h new file mode 100644 index 000000000..02bba0048 --- /dev/null +++ b/src/index/hnsw/impl/FederVisitor.h @@ -0,0 +1,49 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
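+
+// Note: the feder pointer consumed by this visitor comes from
+// SearchParametersHNSWWrapper::feder, which is populated only when the `trace_visit`
+// search option is set (see BaseFaissRegularIndexHNSWNode::Search in faiss_hnsw.cc).
+// The visitor is handed to v2_hnsw_searcher, which invokes visit_level()/visit_edge()
+// while traversing the graph; with a null feder pointer both calls are no-ops.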
+ +#pragma once + +#include + +#include "knowhere/feder/HNSW.h" + +namespace knowhere { + +// a default feder visitor +struct FederVisitor { + using storage_idx_t = faiss::HNSW::storage_idx_t; + + // a non-owning pointer + knowhere::feder::hnsw::FederResult* feder = nullptr; + + inline FederVisitor(knowhere::feder::hnsw::FederResult* const feder_v) : feder{feder_v} { + } + + // + inline void + visit_level(const int level) { + if (feder != nullptr) { + feder->visit_info_.AddLevelVisitRecord(level); + } + } + + // + inline void + visit_edge(const int level, const storage_idx_t node_from, const storage_idx_t node_to, const float distance) { + if (feder != nullptr) { + feder->visit_info_.AddVisitRecord(level, node_from, node_to, distance); + feder->id_set_.insert(node_from); + feder->id_set_.insert(node_to); + } + } +}; + +} // namespace knowhere diff --git a/src/index/hnsw/impl/IndexBruteForceWrapper.cc b/src/index/hnsw/impl/IndexBruteForceWrapper.cc new file mode 100644 index 000000000..ffd3f74cf --- /dev/null +++ b/src/index/hnsw/impl/IndexBruteForceWrapper.cc @@ -0,0 +1,80 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "index/hnsw/impl/IndexBruteForceWrapper.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "index/hnsw/impl/BitsetFilter.h" +#include "knowhere/bitsetview.h" +#include "knowhere/bitsetview_idselector.h" + +namespace knowhere { + +using idx_t = faiss::idx_t; + +// +IndexBruteForceWrapper::IndexBruteForceWrapper(faiss::Index* underlying_index) + : faiss::cppcontrib::knowhere::IndexWrapper{underlying_index} { +} + +void +IndexBruteForceWrapper::search(faiss::idx_t n, const float* __restrict x, faiss::idx_t k, float* __restrict distances, + faiss::idx_t* __restrict labels, const faiss::SearchParameters* params) const { + FAISS_THROW_IF_NOT(k > 0); + + std::unique_ptr dis(index->get_distance_computer()); + + // no parallelism by design + for (idx_t i = 0; i < n; i++) { + // prepare the query + dis->set_query(x + i * index->d); + + // allocate heap + idx_t* const local_ids = labels + i * index->d; + float* const local_distances = distances + i * index->d; + + // set up a filter + faiss::IDSelector* sel = (params == nullptr) ? 
nullptr : params->sel; + + // sel is assumed to be non-null + if (sel == nullptr) { + throw; + } + + // try knowhere-specific filter + const knowhere::BitsetViewIDSelector* bw_idselector = dynamic_cast(sel); + + knowhere::BitsetFilter filter(bw_idselector->bitset_view); + + if (is_similarity_metric(index->metric_type)) { + using C = faiss::CMin; + + faiss::cppcontrib::knowhere::brute_force_search_impl( + index->ntotal, *dis, filter, k, local_distances, local_ids); + } else { + using C = faiss::CMax; + + faiss::cppcontrib::knowhere::brute_force_search_impl( + index->ntotal, *dis, filter, k, local_distances, local_ids); + } + } +} + +} // namespace knowhere diff --git a/src/index/hnsw/impl/IndexBruteForceWrapper.h b/src/index/hnsw/impl/IndexBruteForceWrapper.h new file mode 100644 index 000000000..927869055 --- /dev/null +++ b/src/index/hnsw/impl/IndexBruteForceWrapper.h @@ -0,0 +1,29 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include + +namespace knowhere { + +// override a search procedure to perform a brute-force search. +struct IndexBruteForceWrapper : public faiss::cppcontrib::knowhere::IndexWrapper { + IndexBruteForceWrapper(faiss::Index* underlying_index); + + /// entry point for search + void + search(faiss::idx_t n, const float* x, faiss::idx_t k, float* distances, faiss::idx_t* labels, + const faiss::SearchParameters* params) const override; +}; + +} // namespace knowhere diff --git a/src/index/hnsw/impl/IndexHNSWWrapper.cc b/src/index/hnsw/impl/IndexHNSWWrapper.cc new file mode 100644 index 000000000..b8d22a31a --- /dev/null +++ b/src/index/hnsw/impl/IndexHNSWWrapper.cc @@ -0,0 +1,241 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
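+
+// This wrapper replaces the stock faiss::IndexHNSW::search() with the knowhere
+// v2_hnsw_searcher. The search loop below is instantiated four ways -- with or
+// without a knowhere bitset filter, and with or without a feder visitor -- so that
+// unused hooks are compiled away rather than paid for through virtual calls.
+// kAlpha is the filtering-aware traversal parameter; the caller in faiss_hnsw.cc
+// derives it as bitset.filter_ratio() * 0.7f.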
+ +#include "index/hnsw/impl/IndexHNSWWrapper.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "index/hnsw/impl/BitsetFilter.h" +#include "index/hnsw/impl/FederVisitor.h" +#include "knowhere/bitsetview.h" +#include "knowhere/bitsetview_idselector.h" + +#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) +#include "knowhere/prometheus_client.h" +#endif + +namespace knowhere { + +/************************************************************** + * Utilities + **************************************************************/ + +namespace { + +// cloned from IndexHNSW.cpp +faiss::DistanceComputer* +storage_distance_computer(const faiss::Index* storage) { + if (faiss::is_similarity_metric(storage->metric_type)) { + return new faiss::NegativeDistanceComputer(storage->get_distance_computer()); + } else { + return storage->get_distance_computer(); + } +} + +// a visitor that does nothing +struct DummyVisitor { + using storage_idx_t = faiss::HNSW::storage_idx_t; + + void + visit_level(const int level) { + // does nothing + } + + void + visit_edge(const int level, const storage_idx_t node_from, const storage_idx_t node_to, const float distance) { + // does nothing + } +}; + +} // namespace + +/************************************************************** + * IndexHNSWWrapper implementation + **************************************************************/ + +using idx_t = faiss::idx_t; + +IndexHNSWWrapper::IndexHNSWWrapper(faiss::IndexHNSW* underlying_index) + : faiss::cppcontrib::knowhere::IndexWrapper(underlying_index) { +} + +void +IndexHNSWWrapper::search(idx_t n, const float* __restrict x, idx_t k, float* __restrict distances, + idx_t* __restrict labels, const faiss::SearchParameters* params_in) const { + FAISS_THROW_IF_NOT(k > 0); + + const faiss::IndexHNSW* index_hnsw = dynamic_cast(index); + FAISS_THROW_IF_NOT(index_hnsw); + + FAISS_THROW_IF_NOT_MSG(index_hnsw->storage, "No storage index"); + + // set up + using C = faiss::HNSW::C; + + // check if the graph is empty + if (index_hnsw->hnsw.entry_point == -1) { + for (idx_t i = 0; i < k * n; i++) { + distances[i] = C::neutral(); + labels[i] = -1; + } + + return; + } + + // check parameters + const SearchParametersHNSWWrapper* params = nullptr; + const faiss::HNSW& hnsw = index_hnsw->hnsw; + + float kAlpha = 0.0f; + if (params_in) { + params = dynamic_cast(params_in); + FAISS_THROW_IF_NOT_MSG(params, "params type invalid"); + + kAlpha = params->kAlpha; + } + + // set up hnsw_stats + faiss::HNSWStats* __restrict const hnsw_stats = (params == nullptr) ? nullptr : params->hnsw_stats; + + // + size_t n1 = 0; + size_t n2 = 0; + size_t ndis = 0; + size_t nhops = 0; + + // + faiss::cppcontrib::knowhere::Bitset bitset_visited_nodes = + faiss::cppcontrib::knowhere::Bitset::create_uninitialized(index->ntotal); + + // create a distance computer + std::unique_ptr dis(storage_distance_computer(index_hnsw->storage)); + + // no parallelism by design + for (idx_t i = 0; i < n; i++) { + // prepare the query + dis->set_query(x + i * index->d); + + // prepare the table of visited elements + bitset_visited_nodes.clear(); + + // a visitor + knowhere::feder::hnsw::FederResult* feder = (params == nullptr) ? nullptr : params->feder; + + // future results + faiss::HNSWStats local_stats; + + // set up a filter + faiss::IDSelector* sel = (params == nullptr) ? 
nullptr : params->sel; + + // try knowhere-specific filter + const knowhere::BitsetViewIDSelector* bw_idselector = dynamic_cast(sel); + + if (bw_idselector == nullptr) { + // no filter + faiss::cppcontrib::knowhere::AllowAllFilter filter; + + // feder templating is important, bcz it removes an unneeded 'CALL' instruction. + if (feder == nullptr) { + // no feder + DummyVisitor graph_visitor; + + using searcher_type = + faiss::cppcontrib::knowhere::v2_hnsw_searcher; + + searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, filter, kAlpha, params}; + + local_stats = searcher.search(k, distances + i * k, labels + i * k); + } else { + // use feder + FederVisitor graph_visitor(feder); + + using searcher_type = + faiss::cppcontrib::knowhere::v2_hnsw_searcher; + + searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, filter, kAlpha, params}; + + local_stats = searcher.search(k, distances + i * k, labels + i * k); + } + } else { + // with filter + knowhere::BitsetFilter filter(bw_idselector->bitset_view); + + // feder templating is important, bcz it removes an unneeded 'CALL' instruction. + if (feder == nullptr) { + // no feder + DummyVisitor graph_visitor; + + using searcher_type = faiss::cppcontrib::knowhere::v2_hnsw_searcher< + faiss::DistanceComputer, DummyVisitor, faiss::cppcontrib::knowhere::Bitset, knowhere::BitsetFilter>; + + searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, filter, kAlpha, params}; + + local_stats = searcher.search(k, distances + i * k, labels + i * k); + } else { + // use feder + FederVisitor graph_visitor(feder); + + using searcher_type = faiss::cppcontrib::knowhere::v2_hnsw_searcher< + faiss::DistanceComputer, FederVisitor, faiss::cppcontrib::knowhere::Bitset, knowhere::BitsetFilter>; + + searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, filter, kAlpha, params}; + + local_stats = searcher.search(k, distances + i * k, labels + i * k); + } + } + + // record some statistics +#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) + knowhere::knowhere_hnsw_search_hops.Observe(local_stats.nhops); +#endif + + // update stats if possible + if (hnsw_stats != nullptr) { + n1 += local_stats.n1; + n2 += local_stats.n2; + ndis += local_stats.ndis; + nhops += local_stats.nhops; + } + } + + // update stats if possible + if (hnsw_stats != nullptr) { + hnsw_stats->combine({n1, n2, ndis, nhops}); + } + + // done, update the results, if needed + if (is_similarity_metric(index->metric_type)) { + // we need to revert the negated distances + for (idx_t i = 0; i < k * n; i++) { + distances[i] = -distances[i]; + } + } +} + +} // namespace knowhere diff --git a/src/index/hnsw/impl/IndexHNSWWrapper.h b/src/index/hnsw/impl/IndexHNSWWrapper.h new file mode 100644 index 000000000..5f51d0bba --- /dev/null +++ b/src/index/hnsw/impl/IndexHNSWWrapper.h @@ -0,0 +1,51 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
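+
+// Minimal usage sketch, mirroring BaseFaissRegularIndexHNSWNode::Search; `index_hnsw`,
+// `query`, `distances`, `labels` and the ID selector are assumed to be prepared by
+// the caller:
+//
+//   knowhere::SearchParametersHNSWWrapper params;
+//   params.efSearch = 64;
+//   params.kAlpha = bitset.filter_ratio() * 0.7f;    // filtering-aware traversal
+//   params.sel = &bitset_id_selector;                // or nullptr if nothing is filtered
+//
+//   knowhere::IndexHNSWWrapper wrapper(index_hnsw);  // does not own the index
+//   wrapper.search(1, query, k, distances, labels, &params);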
+ +#pragma once + +#include +#include + +#include +#include + +#include "knowhere/feder/HNSW.h" + +namespace knowhere { + +// Custom parameters for IndexHNSW. +struct SearchParametersHNSWWrapper : public faiss::SearchParametersHNSW { + // Stats will be updated if the object pointer is provided. + faiss::HNSWStats* hnsw_stats = nullptr; + // feder will be updated if the object pointer is provided. + knowhere::feder::hnsw::FederResult* feder = nullptr; + // filtering parameter + float kAlpha = 1.0f; + + inline ~SearchParametersHNSWWrapper() { + } +}; + +// TODO: +// Please note that this particular searcher is int32_t based, so won't +// work correctly for 2B+ samples. This can be easily changed, if needed. + +// override a search() procedure for IndexHNSW. +struct IndexHNSWWrapper : public faiss::cppcontrib::knowhere::IndexWrapper { + IndexHNSWWrapper(faiss::IndexHNSW* underlying_index); + + /// entry point for search + void + search(faiss::idx_t n, const float* x, faiss::idx_t k, float* distances, faiss::idx_t* labels, + const faiss::SearchParameters* params) const override; +}; + +} // namespace knowhere diff --git a/src/index/hnsw/impl/IndexWrapperCosine.cc b/src/index/hnsw/impl/IndexWrapperCosine.cc new file mode 100644 index 000000000..e022d1028 --- /dev/null +++ b/src/index/hnsw/impl/IndexWrapperCosine.cc @@ -0,0 +1,29 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "index/hnsw/impl/IndexWrapperCosine.h" + +#include + +namespace knowhere { + +// a wrapper that overrides a distance computer +IndexWrapperCosine::IndexWrapperCosine(faiss::Index* index, const float* inverse_l2_norms_in) + : faiss::cppcontrib::knowhere::IndexWrapper(index), inverse_l2_norms{inverse_l2_norms_in} { +} + +faiss::DistanceComputer* +IndexWrapperCosine::get_distance_computer() const { + return new faiss::WithCosineNormDistanceComputer( + inverse_l2_norms, index->d, std::unique_ptr(index->get_distance_computer())); +} + +} // namespace knowhere diff --git a/src/index/hnsw/impl/IndexWrapperCosine.h b/src/index/hnsw/impl/IndexWrapperCosine.h new file mode 100644 index 000000000..eb048876a --- /dev/null +++ b/src/index/hnsw/impl/IndexWrapperCosine.h @@ -0,0 +1,32 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
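+
+// Usage sketch, following the cosine branch of BaseFaissRegularIndexHNSWNode::Search:
+// the refine step re-ranks candidates with exact cosine distances by pairing the flat
+// refine storage with the inverse L2 norms cached by the cosine-aware storage
+// (HasInverseL2Norms::get_inverse_l2_norms()). `base_wrapper`, `query` and
+// `refine_params` are assumed to be set up by the caller.
+//
+//   const float* inv_norms = /* get_inverse_l2_norms() of the cosine storage */;
+//   knowhere::IndexWrapperCosine cosine_wrapper(index_refine->refine_index, inv_norms);
+//   faiss::IndexRefine tmp_refine(base_wrapper, &cosine_wrapper);   // owns nothing
+//   tmp_refine.search(1, query, k, distances, labels, &refine_params);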
+ +#pragma once + +#include +#include +#include + +namespace knowhere { + +// overrides a distance compute function +struct IndexWrapperCosine : public faiss::cppcontrib::knowhere::IndexWrapper { + // a non-owning pointer + const float* inverse_l2_norms; + + // norms are external + IndexWrapperCosine(faiss::Index* index, const float* inverse_l2_norms_in); + + faiss::DistanceComputer* + get_distance_computer() const override; +}; + +} // namespace knowhere diff --git a/thirdparty/faiss/benchs/bench_hnsw_knowhere.cpp b/thirdparty/faiss/benchs/bench_hnsw_knowhere.cpp new file mode 100644 index 000000000..09d70f14f --- /dev/null +++ b/thirdparty/faiss/benchs/bench_hnsw_knowhere.cpp @@ -0,0 +1,206 @@ +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +std::vector generate_dataset(const size_t n, const size_t d, uint64_t seed) { + std::default_random_engine rng(seed); + std::uniform_real_distribution u(-1, 1); + + std::vector data(n * d); + for (size_t i = 0; i < data.size(); i++) { + data[i] = u(rng); + } + + return data; +} + +float get_recall_rate( + const size_t nq, + const size_t k, + const std::vector& baseline, + const std::vector& candidate +) { + size_t n = 0; + for (size_t i = 0; i < nq; i++) { + std::unordered_set a_set(k * 4); + + for (size_t j = 0; j < k; j++) { + a_set.insert(baseline[i * k + j]); + } + + for (size_t j = 0; j < k; j++) { + auto itr = a_set.find(candidate[i * k + j]); + if (itr != a_set.cend()) { + n += 1; + } + } + } + + return (float)n / candidate.size(); +} + +struct StopWatch { + using timepoint_t = std::chrono::time_point; + timepoint_t Start; + + StopWatch() { + Start = std::chrono::steady_clock::now(); + } + + double elapsed() const { + const auto now = std::chrono::steady_clock::now(); + std::chrono::duration elapsed = now - Start; + return elapsed.count(); + } +}; + +void test(const size_t nt, const size_t d, const size_t nq, const size_t k) { + // generate a dataset for train + std::vector xt = generate_dataset(nt, d, 123); + + // create an baseline + std::unique_ptr baseline_index( + faiss::index_factory(d, "Flat", faiss::MetricType::METRIC_L2) + ); + baseline_index->train(nt, xt.data()); + baseline_index->add(nt, xt.data()); + + // create an hnsw index + std::unique_ptr hnsw_index( + faiss::index_factory(d, "HNSW32,Flat", faiss::MetricType::METRIC_L2) + ); + hnsw_index->train(nt, xt.data()); + hnsw_index->add(nt, xt.data()); + + // generate a query dataset + std::vector xq = generate_dataset(nq, d, 123); + + // a seed + std::default_random_engine rng(789); + + // perform evaluation with a different level of filtering + for (const size_t percent : {0, 1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 99}) { + // generate a bitset with a given percentage + std::vector ids_to_use(nt); + std::iota(ids_to_use.begin(), ids_to_use.end(), 0); + + std::shuffle(ids_to_use.begin(), ids_to_use.end(), rng); + + // number of points to use + const size_t nt_real = size_t(std::max(1.0, nt - (nt * percent / 100.0))); + + // create a bitset + faiss::cppcontrib::knowhere::Bitset bitset = + faiss::cppcontrib::knowhere::Bitset::create_cleared(nt); + for (size_t i = 0; i < nt_real; i++) { + bitset.set(ids_to_use[i]); + } + + // create an IDSelector + faiss::IDSelectorBitmap sel(nt, bitset.bits.get()); + + // the quant of a search + const size_t nbatch = nq; + + // perform a baseline search + std::vector baseline_dis(k * nq, -1); + std::vector baseline_ids(k * nq, -1); + + 
faiss::SearchParameters baseline_params; + baseline_params.sel = &sel; + + StopWatch sw_baseline; + for (size_t p = 0; p < nq; p += nbatch) { + size_t p0 = std::min(nq, p + nbatch); + size_t np = p0 - p; + + baseline_index->search( + np, + xq.data() + p * d, + k, + baseline_dis.data() + k * p, + baseline_ids.data() + k * p, + &baseline_params); + } + double baseline_elapsed = sw_baseline.elapsed(); + + // perform an hnsw search + std::vector hnsw_dis(k * nq, -1); + std::vector hnsw_ids(k * nq, -1); + + faiss::SearchParametersHNSW hnsw_params; + hnsw_params.sel = &sel; + hnsw_params.efSearch = 64; + + StopWatch sw_hnsw; + for (size_t p = 0; p < nq; p += nbatch) { + size_t p0 = std::min(nq, p + nbatch); + size_t np = p0 - p; + + // hnsw_index->search(nq, xq.data(), k, hnsw_dis.data(), hnsw_ids.data(), &hnsw_params); + hnsw_index->search( + np, + xq.data() + p * d, + k, + hnsw_dis.data() + k * p, + hnsw_ids.data() + k * p, + &hnsw_params); + } + double hnsw_elapsed = sw_hnsw.elapsed(); + + // perform a cppcontrib/knowhere search + std::vector hnsw_candidate_dis(k * nq, -1); + std::vector hnsw_candidate_ids(k * nq, -1); + + faiss::cppcontrib::knowhere::SearchParametersHNSWWrapper hnsw_candidate_params; + hnsw_candidate_params.sel = &sel; + hnsw_candidate_params.kAlpha = ((float)nt_real / nt) * 0.7f; + hnsw_candidate_params.efSearch = 64; + + faiss::cppcontrib::knowhere::IndexHNSWWrapper wrapper( + dynamic_cast(hnsw_index.get())); + + StopWatch sw_hnsw_candidate; + for (size_t p = 0; p < nq; p += nbatch) { + size_t p0 = std::min(nq, p + nbatch); + size_t np = p0 - p; + + // wrapper.search(nq, xq.data(), k, hnsw_candidate_dis.data(), hnsw_candidate_ids.data(), &hnsw_candidate_params); + wrapper.search( + np, + xq.data() + p * d, + k, + hnsw_candidate_dis.data() + k * p, + hnsw_candidate_ids.data() + k * p, + &hnsw_candidate_params); + } + double hnsw_candidate_elapsed = sw_hnsw_candidate.elapsed(); + + // compute the recall rate + const float recall_hnsw = get_recall_rate(nq, k, baseline_ids, hnsw_ids); + const float recall_hnsw_candidate = get_recall_rate(nq, k, baseline_ids, hnsw_candidate_ids); + + printf("p, d=%zd, nt=%zd, nq=%zd, percent=%zd, recall_hnsw=%f, recall_hnsw_candidate=%f, timings %f %f %f\n", + d, nt, nq, percent, recall_hnsw, recall_hnsw_candidate, + baseline_elapsed, hnsw_elapsed, hnsw_candidate_elapsed); + } +} + +int main() { + test(65536 * 4, 128, 1024, 64); + + return 0; +} \ No newline at end of file diff --git a/thirdparty/faiss/faiss/IndexCosine.cpp b/thirdparty/faiss/faiss/IndexCosine.cpp new file mode 100644 index 000000000..df4d912cc --- /dev/null +++ b/thirdparty/faiss/faiss/IndexCosine.cpp @@ -0,0 +1,319 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
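+
+// All "cosine distances" produced in this file are similarities of the form
+//   cos(q, x) = <q, x> * (1 / ||q||) * (1 / ||x||)
+// The per-vector inverse norms 1 / ||x|| are cached at add() time in L2NormsStorage,
+// so every distance evaluation costs one inner product plus two multiplications.
+// Zero-norm vectors are stored with an inverse norm of 1.0f so the product stays finite.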
+ +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace faiss { + +////////////////////////////////////////////////////////////////////////////////// + +// +struct FlatCosineDis : FlatCodesDistanceComputer { + size_t d; + idx_t nb; + const float* q; + const float* b; + size_t ndis; + + const float* inverse_l2_norms; + float inverse_query_norm = 0; + + float distance_to_code(const uint8_t* code) final { + ndis++; + const float norm = fvec_norm_L2sqr((const float*)code, d); + return (norm == 0) ? 0 : (fvec_inner_product(q, (const float*)code, d) / sqrtf(norm) * inverse_query_norm); + } + + float operator()(const idx_t i) final override { + const float* __restrict y_i = + reinterpret_cast(codes + i * code_size); + + prefetch_L2(inverse_l2_norms + i); + + const float dp0 = fvec_inner_product(q, y_i, d); + + const float inverse_code_norm_i = inverse_l2_norms[i]; + const float distance = dp0 * inverse_code_norm_i * inverse_query_norm; + return distance; + } + + float symmetric_dis(idx_t i, idx_t j) final override { + const float* __restrict y_i = + reinterpret_cast(codes + i * code_size); + const float* __restrict y_j = + reinterpret_cast(codes + j * code_size); + + prefetch_L2(inverse_l2_norms + i); + prefetch_L2(inverse_l2_norms + j); + + const float dp0 = fvec_inner_product(y_i, y_j, d); + + const float inverse_code_norm_i = inverse_l2_norms[i]; + const float inverse_code_norm_j = inverse_l2_norms[j]; + + return dp0 * inverse_code_norm_i * inverse_code_norm_j; + } + + explicit FlatCosineDis(const IndexFlatCosine& storage, const float* q = nullptr) + : FlatCodesDistanceComputer( + storage.codes.data(), + storage.code_size), + d(storage.d), + nb(storage.ntotal), + q(q), + b(storage.get_xb()), + ndis(0) { + // it is the caller's responsibility to ensure that everything is all right. + inverse_l2_norms = storage.get_inverse_l2_norms(); + + if (q != nullptr) { + const float query_l2norm = fvec_norm_L2sqr(q, d); + inverse_query_norm = (query_l2norm <= 0) ? 1.0f : (1.0f / sqrtf(query_l2norm)); + } else { + inverse_query_norm = 0; + } + } + + void set_query(const float* x) final override { + q = x; + + if (q != nullptr) { + const float query_l2norm = fvec_norm_L2sqr(q, d); + inverse_query_norm = (query_l2norm <= 0) ? 
1.0f : (1.0f / sqrtf(query_l2norm)); + } else { + inverse_query_norm = 0; + } + } + + // compute four distances + void distances_batch_4( + const idx_t idx0, + const idx_t idx1, + const idx_t idx2, + const idx_t idx3, + float& dis0, + float& dis1, + float& dis2, + float& dis3) final override { + ndis += 4; + + // compute first, assign next + const float* __restrict y0 = + reinterpret_cast(codes + idx0 * code_size); + const float* __restrict y1 = + reinterpret_cast(codes + idx1 * code_size); + const float* __restrict y2 = + reinterpret_cast(codes + idx2 * code_size); + const float* __restrict y3 = + reinterpret_cast(codes + idx3 * code_size); + + prefetch_L2(inverse_l2_norms + idx0); + prefetch_L2(inverse_l2_norms + idx1); + prefetch_L2(inverse_l2_norms + idx2); + prefetch_L2(inverse_l2_norms + idx3); + + float dp0 = 0; + float dp1 = 0; + float dp2 = 0; + float dp3 = 0; + fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3); + + const float inverse_code_norm0 = inverse_l2_norms[idx0]; + const float inverse_code_norm1 = inverse_l2_norms[idx1]; + const float inverse_code_norm2 = inverse_l2_norms[idx2]; + const float inverse_code_norm3 = inverse_l2_norms[idx3]; + + dis0 = dp0 * inverse_code_norm0 * inverse_query_norm; + dis1 = dp1 * inverse_code_norm1 * inverse_query_norm; + dis2 = dp2 * inverse_code_norm2 * inverse_query_norm; + dis3 = dp3 * inverse_code_norm3 * inverse_query_norm; + } +}; + + +////////////////////////////////////////////////////////////////////////////////// + +// initialize in a custom way +WithCosineNormDistanceComputer::WithCosineNormDistanceComputer( + const float* inverse_l2_norms_, + const int d_, + std::unique_ptr&& basedis_) : +basedis(std::move(basedis_)), inverse_l2_norms{inverse_l2_norms_}, d{d_} {} + +// the query remains untouched. It is a caller's responsibility +// to normalize it. +void WithCosineNormDistanceComputer::set_query(const float* x) { + basedis->set_query(x); + + if (x != nullptr) { + const float query_l2norm = faiss::fvec_norm_L2sqr(x, d); + inverse_query_norm = (query_l2norm <= 0) ? 
1.0f : (1.0f / sqrtf(query_l2norm)); + } else { + inverse_query_norm = 0; + } +} + +/// compute distance of vector i to current query +float WithCosineNormDistanceComputer::operator()(idx_t i) { + prefetch_L2(inverse_l2_norms + i); + + float dis = (*basedis)(i); + dis *= inverse_l2_norms[i] * inverse_query_norm; + + return dis; +} + +void WithCosineNormDistanceComputer::distances_batch_4( + const idx_t idx0, + const idx_t idx1, + const idx_t idx2, + const idx_t idx3, + float& dis0, + float& dis1, + float& dis2, + float& dis3) { + prefetch_L2(inverse_l2_norms + idx0); + prefetch_L2(inverse_l2_norms + idx1); + prefetch_L2(inverse_l2_norms + idx2); + prefetch_L2(inverse_l2_norms + idx3); + + basedis->distances_batch_4( + idx0, idx1, idx2, idx3, dis0, dis1, dis2, dis3); + + dis0 = dis0 * inverse_l2_norms[idx0] * inverse_query_norm; + dis1 = dis1 * inverse_l2_norms[idx1] * inverse_query_norm; + dis2 = dis2 * inverse_l2_norms[idx2] * inverse_query_norm; + dis3 = dis3 * inverse_l2_norms[idx3] * inverse_query_norm; +} + +/// compute distance between two stored vectors +float WithCosineNormDistanceComputer::symmetric_dis(idx_t i, idx_t j) { + prefetch_L2(inverse_l2_norms + i); + prefetch_L2(inverse_l2_norms + j); + + float v = basedis->symmetric_dis(i, j); + v *= inverse_l2_norms[i]; + v *= inverse_l2_norms[j]; + return v; +} + + +////////////////////////////////////////////////////////////////////////////////// + +L2NormsStorage L2NormsStorage::from_l2_norms(const std::vector& l2_norms) { + L2NormsStorage result; + result.add_l2_norms(l2_norms.data(), l2_norms.size()); + return result; +} + +void L2NormsStorage::add(const float* x, const idx_t n, const idx_t d) { + const size_t current_size = inverse_l2_norms.size(); + inverse_l2_norms.resize(current_size + n); + + for (idx_t i = 0; i < n; i++) { + const float l2sqr_norm = fvec_norm_L2sqr(x + i * d, d); + const float inverse_l2_norm = (l2sqr_norm == 0.0f) ? 1.0f : (1.0f / sqrtf(l2sqr_norm)); + inverse_l2_norms[i + current_size] = inverse_l2_norm; + } +} + +void L2NormsStorage::add_l2_norms(const float* l2_norms, const idx_t n) { + const size_t current_size = inverse_l2_norms.size(); + inverse_l2_norms.resize(current_size + n); + for (idx_t i = 0; i < n; i++) { + const float l2sqr_norm = l2_norms[i]; + const float inverse_l2_norm = (l2sqr_norm == 0.0f) ? 1.0f : (1.0f / l2sqr_norm); + inverse_l2_norms[i + current_size] = inverse_l2_norm; + } +} + +void L2NormsStorage::reset() { + inverse_l2_norms.clear(); +} + +std::vector L2NormsStorage::as_l2_norms() const { + std::vector result(inverse_l2_norms.size()); + for (size_t i = 0; i < inverse_l2_norms.size(); i++) { + result[i] = 1.0f / inverse_l2_norms[i]; + } + + return result; +} + + +////////////////////////////////////////////////////////////////////////////////// + +// +IndexFlatCosine::IndexFlatCosine() : IndexFlat() { + metric_type = MetricType::METRIC_INNER_PRODUCT; + is_cosine = true; +} + +// +IndexFlatCosine::IndexFlatCosine(idx_t d) : IndexFlat(d, MetricType::METRIC_INNER_PRODUCT, true) {} + +// +void IndexFlatCosine::add(idx_t n, const float* x) { + FAISS_THROW_IF_NOT(is_trained); + if (n == 0) { + return; + } + + // todo aguzhva: + // it is a tricky situation at this moment, because IndexFlatCosine + // contains duplicate norms (one in IndexFlatCodes and one in HasInverseL2Norms). + // Norms in IndexFlatCodes are going to be removed in the future. 
+ IndexFlat::add(n, x); + inverse_norms_storage.add(x, n, d); +} + +void IndexFlatCosine::reset() { + IndexFlat::reset(); + inverse_norms_storage.reset(); +} + +const float* IndexFlatCosine::get_inverse_l2_norms() const { + return inverse_norms_storage.inverse_l2_norms.data(); +} + +// +FlatCodesDistanceComputer* IndexFlatCosine::get_FlatCodesDistanceComputer() const { + return new FlatCosineDis(*this); +} + + +////////////////////////////////////////////////////////////////////////////////// + +// +IndexHNSWFlatCosine::IndexHNSWFlatCosine() { + is_trained = true; +} + +IndexHNSWFlatCosine::IndexHNSWFlatCosine(int d, int M) : + IndexHNSW(new IndexFlatCosine(d), M) +{ + own_fields = true; + is_trained = true; +} + + +} + diff --git a/thirdparty/faiss/faiss/IndexCosine.h b/thirdparty/faiss/faiss/IndexCosine.h new file mode 100644 index 000000000..0b952fd3d --- /dev/null +++ b/thirdparty/faiss/faiss/IndexCosine.h @@ -0,0 +1,110 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +// knowhere-specific indices + +#pragma once + +#include +#include +#include + + +namespace faiss { + +// a distance computer wrapper that normalizes the distance over a query +struct WithCosineNormDistanceComputer : DistanceComputer { + /// owned by this + std::unique_ptr basedis; + // not owned by this + const float* inverse_l2_norms = nullptr; + // computed internally + float inverse_query_norm = 0; + // cached dimensionality + int d = 0; + + // initialize in a custom way + WithCosineNormDistanceComputer( + const float* inverse_l2_norms_, + const int d_, + std::unique_ptr&& basedis_); + + // the query remains untouched. It is a caller's responsibility + // to normalize it. + void set_query(const float* x) override; + + /// compute distance of vector i to current query + float operator()(idx_t i) override; + + void distances_batch_4( + const idx_t idx0, + const idx_t idx1, + const idx_t idx2, + const idx_t idx3, + float& dis0, + float& dis1, + float& dis2, + float& dis3) override; + + /// compute distance between two stored vectors + float symmetric_dis(idx_t i, idx_t j) override; +}; + +struct HasInverseL2Norms { + virtual ~HasInverseL2Norms() = default; + + virtual const float* get_inverse_l2_norms() const = 0; +}; + +// a supporting storage for L2 norms +struct L2NormsStorage { + std::vector inverse_l2_norms; + + // create from a vector of L2 norms (sqrt(sum(x^2))) + static L2NormsStorage from_l2_norms(const std::vector& l2_norms); + + // add vectors + void add(const float* x, const idx_t n, const idx_t d); + + // add L2 norms (sqrt(sum(x^2))) + void add_l2_norms(const float* l2_norms, const idx_t n); + + // clear the storage + void reset(); + + // produces a vector of L2 norms, effectively inverting inverse_l2_norms + std::vector as_l2_norms() const; +}; + +// A dedicated index used for Cosine Distance in the future. 
+struct IndexFlatCosine : IndexFlat, HasInverseL2Norms { + L2NormsStorage inverse_norms_storage; + + IndexFlatCosine(); + IndexFlatCosine(idx_t d); + + void add(idx_t n, const float* x) override; + void reset() override; + + FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override; + + const float* get_inverse_l2_norms() const override; +}; + +// +struct IndexHNSWFlatCosine : IndexHNSW { + IndexHNSWFlatCosine(); + IndexHNSWFlatCosine(int d, int M); +}; + + + +} diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.cpp b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.cpp new file mode 100644 index 000000000..4c6cad23d --- /dev/null +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.cpp @@ -0,0 +1,103 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +IndexBruteForceWrapper::IndexBruteForceWrapper(Index* underlying_index) : + IndexWrapper{underlying_index} {} + +void IndexBruteForceWrapper::search( + idx_t n, + const float* __restrict x, + idx_t k, + float* __restrict distances, + idx_t* __restrict labels, + const SearchParameters* params +) const { + FAISS_THROW_IF_NOT(k > 0); + + idx_t check_period = InterruptCallback::get_period_hint( + index->d * index->ntotal); + + for (idx_t i0 = 0; i0 < n; i0 += check_period) { + idx_t i1 = std::min(i0 + check_period, n); + +#pragma omp parallel if (i1 - i0 > 1) + { + std::unique_ptr dis(index->get_distance_computer()); + +#pragma omp for schedule(guided) + for (idx_t i = i0; i < i1; i++) { + // prepare the query + dis->set_query(x + i * index->d); + + // allocate heap + idx_t* const local_ids = labels + i * index->d; + float* const local_distances = distances + i * index->d; + + // set up a filter + IDSelector* sel = (params == nullptr) ? nullptr : params->sel; + + // template, just in case a filter type will be specialized + // in order to remove virtual function call overhead. + using filter_type = DefaultIDSelectorFilter; + filter_type filter(sel); + + if (is_similarity_metric(index->metric_type)) { + using C = CMin; + + brute_force_search_impl( + index->ntotal, + *dis, + filter, + k, + local_distances, + local_ids + ); + } else { + using C = CMax; + + brute_force_search_impl( + index->ntotal, + *dis, + filter, + k, + local_distances, + local_ids + ); + } + } + } + + InterruptCallback::check(); + } +} + +} +} +} diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.h b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.h new file mode 100644 index 000000000..2d7762921 --- /dev/null +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.h @@ -0,0 +1,39 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include + +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +// override a search procedure to perform a brute-force search. +struct IndexBruteForceWrapper : IndexWrapper { + IndexBruteForceWrapper(Index* underlying_index); + + /// entry point for search + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + const SearchParameters* params + ) const override; +}; + +} +} +} diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.cpp b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.cpp new file mode 100644 index 000000000..6d56383e7 --- /dev/null +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.cpp @@ -0,0 +1,212 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
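For orientation, a rough usage sketch of the wrapper implemented in this file follows (an editorial illustration, not part of the patch). The matching declarations appear in IndexHNSWWrapper.h later in this diff; dim, nb, k, xb and xq are placeholder values, and the underlying index is a plain faiss::IndexHNSWFlat.

// Editorial sketch only; assumes standard faiss headers plus this patch.
#include <vector>

#include <faiss/IndexHNSW.h>
#include <faiss/cppcontrib/knowhere/IndexHNSWWrapper.h>

void example_hnsw_wrapper() {
    const int dim = 32;
    const faiss::idx_t nb = 10000;
    std::vector<float> xb(nb * dim, 0.5f);   // placeholder base vectors
    std::vector<float> xq(dim, 0.5f);        // a single placeholder query

    faiss::IndexHNSWFlat hnsw(dim, 32 /* M */);
    hnsw.add(nb, xb.data());

    // the wrapper does not own the underlying index, so both must stay alive
    faiss::cppcontrib::knowhere::IndexHNSWWrapper wrapper(&hnsw);

    faiss::cppcontrib::knowhere::SearchParametersHNSWWrapper params;
    params.efSearch = 64;    // inherited from faiss::SearchParametersHNSW
    params.kAlpha = 0.1f;    // filtering parameter, see HnswSearcher.h below
    params.sel = nullptr;    // an IDSelector may be plugged in here

    const faiss::idx_t k = 10;
    std::vector<float> distances(k);
    std::vector<faiss::idx_t> labels(k);
    wrapper.search(1, xq.data(), k, distances.data(), labels.data(), &params);
}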
+ +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +// a visitor that does nothing +struct DummyVisitor { + using storage_idx_t = HNSW::storage_idx_t; + + void visit_level(const int level) { + // does nothing + } + + void visit_edge( + const int level, + const storage_idx_t node_from, + const storage_idx_t node_to, + const float distance + ) { + // does nothing + } +}; + + +/************************************************************** + * Utilities + **************************************************************/ + +namespace { + +// cloned from IndexHNSW.cpp +DistanceComputer* storage_distance_computer(const Index* storage) { + if (is_similarity_metric(storage->metric_type)) { + return new NegativeDistanceComputer(storage->get_distance_computer()); + } else { + return storage->get_distance_computer(); + } +} + +} + + +/************************************************************** + * IndexHNSWWrapper implementation + **************************************************************/ + +IndexHNSWWrapper::IndexHNSWWrapper(IndexHNSW* underlying_index) : + IndexWrapper(underlying_index) {} + +void IndexHNSWWrapper::search( + idx_t n, + const float* __restrict x, + idx_t k, + float* __restrict distances, + idx_t* __restrict labels, + const SearchParameters* params_in +) const { + FAISS_THROW_IF_NOT(k > 0); + + const IndexHNSW* index_hnsw = dynamic_cast(index); + FAISS_THROW_IF_NOT(index_hnsw); + + FAISS_THROW_IF_NOT_MSG( + index_hnsw->storage, + "No storage index"); + + // set up + using C = HNSW::C; + + // check if the graph is empty + if (index_hnsw->hnsw.entry_point == -1) { + for (idx_t i = 0; i < k * n; i++) { + distances[i] = C::neutral(); + labels[i] = -1; + } + + return; + } + + // check parameters + const SearchParametersHNSWWrapper* params = nullptr; + const HNSW& hnsw = index_hnsw->hnsw; + + float kAlpha = 0.0f; + int efSearch = hnsw.efSearch; + if (params_in) { + params = dynamic_cast(params_in); + FAISS_THROW_IF_NOT_MSG(params, "params type invalid"); + efSearch = params->efSearch; + kAlpha = params->kAlpha; + } + + // set up hnsw_stats + HNSWStats* __restrict const hnsw_stats = + (params == nullptr) ? nullptr : params->hnsw_stats; + + // + size_t n1 = 0; + size_t n2 = 0; + size_t ndis = 0; + size_t nhops = 0; + + idx_t check_period = InterruptCallback::get_period_hint( + hnsw.max_level * index->d * efSearch); + + for (idx_t i0 = 0; i0 < n; i0 += check_period) { + idx_t i1 = std::min(i0 + check_period, n); + +#pragma omp parallel if (i1 - i0 > 1) + { + Bitset bitset_visited_nodes = Bitset::create_uninitialized(index->ntotal); + + // create a distance computer + std::unique_ptr dis( + storage_distance_computer(index_hnsw->storage)); + +#pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided) + for (idx_t i = i0; i < i1; i++) { + // prepare the query + dis->set_query(x + i * index->d); + + // prepare the table of visited elements + bitset_visited_nodes.clear(); + + // a visitor + DummyVisitor graph_visitor; + + // future results + HNSWStats local_stats; + + // set up a filter + IDSelector* sel = (params == nullptr) ? nullptr : params->sel; + + // template, just in case a filter type will be specialized + // in order to remove virtual function call overhead. 
+ using filter_type = DefaultIDSelectorFilter; + filter_type filter(sel); + + using searcher_type = v2_hnsw_searcher< + DistanceComputer, + DummyVisitor, + Bitset, + filter_type>; + + searcher_type searcher{ + hnsw, + *(dis.get()), + graph_visitor, + bitset_visited_nodes, + filter, + kAlpha, + params + }; + + local_stats = searcher.search(k, distances + i * k, labels + i * k); + + // update stats if possible + if (hnsw_stats != nullptr) { + n1 += local_stats.n1; + n2 += local_stats.n2; + ndis += local_stats.ndis; + nhops += local_stats.nhops; + } + } + } + + InterruptCallback::check(); + } + + // update stats if possible + if (hnsw_stats != nullptr) { + hnsw_stats->combine({n1, n2, ndis, nhops}); + } + + // done, update the results, if needed + if (is_similarity_metric(index->metric_type)) { + // we need to revert the negated distances + for (idx_t i = 0; i < k * n; i++) { + distances[i] = -distances[i]; + } + } +} + +} +} +} diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.h b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.h new file mode 100644 index 000000000..45177f5c3 --- /dev/null +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.h @@ -0,0 +1,56 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include + +#include + +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +// Custom parameters for IndexHNSW. +struct SearchParametersHNSWWrapper : public faiss::SearchParametersHNSW { + // Stats will be updated if the object pointer is provided. + faiss::HNSWStats* hnsw_stats = nullptr; + // filtering parameter. A floating point value within [0.0f, 1.0f range] + float kAlpha = 0.0f; + + inline ~SearchParametersHNSWWrapper() {} +}; + +// TODO: +// Please note that this particular searcher is int32_t based, so won't +// work correctly for 2B+ samples. This can be easily changed, if needed. + +// override a search() procedure for IndexHNSW. +struct IndexHNSWWrapper : IndexWrapper { + IndexHNSWWrapper(faiss::IndexHNSW* underlying_index); + + /// entry point for search + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + const SearchParameters* params = nullptr + ) const override; +}; + +} // namespace knowhere +} +} diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexWrapper.cpp b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexWrapper.cpp new file mode 100644 index 000000000..e581ea346 --- /dev/null +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexWrapper.cpp @@ -0,0 +1,65 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +IndexWrapper::IndexWrapper(Index* underlying_index) : + Index{underlying_index->d, underlying_index->metric_type}, + index{underlying_index} +{ + ntotal = underlying_index->ntotal; + is_trained = underlying_index->is_trained; + verbose = underlying_index->verbose; + metric_arg = underlying_index->metric_arg; +} + +IndexWrapper::~IndexWrapper() {} + +void IndexWrapper::train(idx_t n, const float* x) { + index->train(n, x); + is_trained = index->is_trained; +} + +void IndexWrapper::add(idx_t n, const float* x) { + index->add(n, x); + this->ntotal = index->ntotal; +} + +void IndexWrapper::search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + const SearchParameters* params) const { + index->search(n, x, k, distances, labels, params); +} + +void IndexWrapper::reset() { + index->reset(); + this->ntotal = 0; +} + +void IndexWrapper::merge_from(Index& otherIndex, idx_t add_id) { + index->merge_from(otherIndex, add_id); +} + +DistanceComputer* IndexWrapper::get_distance_computer() const { + return index->get_distance_computer(); +} + +} +} +} \ No newline at end of file diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexWrapper.h b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexWrapper.h new file mode 100644 index 000000000..b160d3b32 --- /dev/null +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexWrapper.h @@ -0,0 +1,51 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +// This index is useful for overriding a certain base functionality +// on the fly. 
+struct IndexWrapper : Index { + // a non-owning pointer + Index* index = nullptr; + + explicit IndexWrapper(Index* underlying_index); + + virtual ~IndexWrapper(); + + void train(idx_t n, const float* x) override; + + void add(idx_t n, const float* x) override; + + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + const SearchParameters* params = nullptr) const override; + + void reset() override; + + void merge_from(Index& otherIndex, idx_t add_id = 0) override; + + DistanceComputer* get_distance_computer() const override; +}; + +} +} +} diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Bruteforce.h b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Bruteforce.h new file mode 100644 index 000000000..f13a35740 --- /dev/null +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Bruteforce.h @@ -0,0 +1,73 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +// C is CMax<> or CMin<> +template +void brute_force_search_impl( + const idx_t ntotal, + DistanceComputerT& qdis, + const FilterT filter, + const idx_t k, + float* __restrict distances, + idx_t* __restrict labels +) { + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + auto max_heap = std::make_unique[]>(k); + idx_t n_added = 0; + for (idx_t idx = 0; idx < ntotal; ++idx) { + if (filter.allowed(idx)) { + const float distance = qdis(idx); + if (n_added < k) { + n_added += 1; + heap_push(n_added, max_heap.get(), distance, idx); + } else if (C::cmp(max_heap[0].first, distance)) { + heap_replace_top(k, max_heap.get(), distance, idx); + } + } + } + + const idx_t len = std::min(n_added, idx_t(k)); + for (idx_t i = 0; i < len; i++) { + labels[len - i - 1] = max_heap[0].second; + distances[len - i - 1] = max_heap[0].first; + + heap_pop(len - i, max_heap.get()); + } + + // fill leftovers + if (len < k) { + for (idx_t idx = len; idx < k; idx++) { + labels[idx] = -1; + distances[idx] = C::neutral(); + } + } +} + +} +} +} diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Filters.h b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Filters.h new file mode 100644 index 000000000..fce3a673b --- /dev/null +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Filters.h @@ -0,0 +1,44 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License.
+
+#pragma once
+
+#include
+
+namespace faiss {
+namespace cppcontrib {
+namespace knowhere {
+
+// a filter that selects according to an IDSelector
+template
+struct DefaultIDSelectorFilter {
+    // contains enabled nodes.
+    // a non-owning pointer.
+    const IDSelectorT* selector = nullptr;
+
+    inline DefaultIDSelectorFilter(const IDSelectorT* selector_) : selector{selector_} {}
+
+    inline bool allowed(const idx_t idx) const {
+        return ((selector == nullptr) || (selector->is_member(idx)));
+    }
+};
+
+// a filter that allows everything
+struct AllowAllFilter {
+    constexpr inline bool allowed(const idx_t) const {
+        return true;
+    }
+};
+
+
+}
+}
+}
diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/HnswSearcher.h b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/HnswSearcher.h
new file mode 100644
index 000000000..29501bfab
--- /dev/null
+++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/HnswSearcher.h
@@ -0,0 +1,415 @@
+// Copyright (C) 2019-2024 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+
+#pragma once
+
+// standard headers
+#include
+#include
+#include
+#include
+#include
+
+// Faiss-specific headers
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+// Knowhere-specific headers
+#include
+
+
+namespace faiss {
+namespace cppcontrib {
+namespace knowhere {
+
+namespace {
+
+// whether to track statistics
+constexpr bool track_hnsw_stats = true;
+
+}
+
+// Accommodates all the search logic and variables.
+/// * DistanceComputerT is responsible for computing distances
+/// * GraphVisitorT records visited edges
+/// * VisitedT is responsible for tracking visited nodes
+/// * FilterT is responsible for filtering unneeded nodes
+/// Interfaces of all templates are tweaked to accept standard Faiss structures
+/// with dynamic dispatching. Custom Knowhere structures are also accepted.
+template
+struct v2_hnsw_searcher {
+    using storage_idx_t = faiss::HNSW::storage_idx_t;
+    using idx_t = faiss::idx_t;
+
+    // hnsw structure.
+    // the reference is not owned.
+    const faiss::HNSW& hnsw;
+
+    // computes distances. it already knows the query vector.
+    // the reference is not owned.
+    DistanceComputerT& qdis;
+
+    // records visited edges.
+    // the reference is not owned.
+    GraphVisitorT& graph_visitor;
+
+    // tracks the nodes that have been visited already.
+    // the reference is not owned.
+    VisitedT& visited_nodes;
+
+    // a filter for disabled nodes.
+    // the reference is not owned.
+    const FilterT filter;
+
+    // parameter for the filtering
+    const float kAlpha;
+
+    // custom parameters of HNSW search.
+    // the pointer is not owned.
+ const faiss::SearchParametersHNSW* params; + + // + v2_hnsw_searcher( + const faiss::HNSW& hnsw_, + DistanceComputerT& qdis_, + GraphVisitorT& graph_visitor_, + VisitedT& visited_nodes_, + const FilterT& filter_, + const float kAlpha_, + const faiss::SearchParametersHNSW* params_) : + hnsw{hnsw_}, + qdis{qdis_}, + graph_visitor{graph_visitor_}, + visited_nodes{visited_nodes_}, + filter{filter_}, + kAlpha{kAlpha_}, + params{params_} {} + + v2_hnsw_searcher(const v2_hnsw_searcher&) = delete; + v2_hnsw_searcher(v2_hnsw_searcher&&) = delete; + v2_hnsw_searcher& operator =(const v2_hnsw_searcher&) = delete; + v2_hnsw_searcher& operator =(v2_hnsw_searcher&&) = delete; + + // greedily update a nearest vector at a given level. + // * the update starts from the value in 'nearest'. + faiss::HNSWStats greedy_update_nearest( + const int level, + storage_idx_t& nearest, + float& d_nearest + ) { + faiss::HNSWStats stats; + + for (;;) { + storage_idx_t prev_nearest = nearest; + + size_t begin = 0; + size_t end = 0; + hnsw.neighbor_range(nearest, level, &begin, &end); + + // prefetch and eval the size + size_t count = 0; + for (size_t i = begin; i < end; i++) { + storage_idx_t v = hnsw.neighbors[i]; + if (v < 0) { + break; + } + + // qdis.prefetch(v); + count += 1; + } + + // visit neighbors + for (size_t i = begin; i < begin + count; i++) { + storage_idx_t v = hnsw.neighbors[i]; + + // compute the distance + const float dis = qdis(v); + + // record a traversed edge + graph_visitor.visit_edge(level, prev_nearest, nearest, dis); + + // check if an update is needed + if (dis < d_nearest) { + nearest = v; + d_nearest = dis; + } + } + + // update stats + if (track_hnsw_stats) { + stats.ndis += count; + stats.nhops += 1; + } + + // we're done if there we no changes + if (nearest == prev_nearest) { + return stats; + } + } + } + + // no loops, just check neighbors of a single node. + template + faiss::HNSWStats evaluate_single_node( + const idx_t node_id, + const int level, + float& accumulated_alpha, + FuncAddCandidate func_add_candidate + ) { + // // unused + // bool do_dis_check = params ? params->check_relative_distance + // : hnsw.check_relative_distance; + + faiss::HNSWStats stats; + + size_t begin = 0; + size_t end = 0; + hnsw.neighbor_range(node_id, level, &begin, &end); + + // todo: add prefetch + size_t counter = 0; + size_t saved_indices[4]; + int saved_statuses[4]; + + size_t ndis = 0; + for (size_t j = begin; j < end; j++) { + const storage_idx_t v1 = hnsw.neighbors[j]; + + if (v1 < 0) { + // no more neighbors + break; + } + + // already visited? + if (visited_nodes.get(v1)) { + // yes, visited. + graph_visitor.visit_edge(level, node_id, v1, -1); + continue; + } + + // not visited. mark as visited. + visited_nodes.set(v1); + + // is the node disabled? 
+ int status = knowhere::Neighbor::kValid; + if (!filter.allowed(v1)) { + // yes, disabled + status = knowhere::Neighbor::kInvalid; + + // sometimes, disabled nodes are allowed to be used + accumulated_alpha += kAlpha; + if (accumulated_alpha < 1.0f) { + continue; + } + + accumulated_alpha -= 1.0f; + } + + saved_indices[counter] = v1; + saved_statuses[counter] = status; + counter += 1; + + ndis += 1; + + if (counter == 4) { + // evaluate 4x distances at once + float dis[4] = {0, 0, 0, 0}; + qdis.distances_batch_4( + saved_indices[0], + saved_indices[1], + saved_indices[2], + saved_indices[3], + dis[0], + dis[1], + dis[2], + dis[3]); + + for (size_t id4 = 0; id4 < 4; id4++) { + // record a traversed edge + graph_visitor.visit_edge(level, node_id, saved_indices[id4], dis[id4]); + + // add a record of visited nodes + knowhere::Neighbor nn(saved_indices[id4], dis[id4], saved_statuses[id4]); + if (func_add_candidate(nn)) { + #if defined(USE_PREFETCH) + // TODO + // _mm_prefetch(get_linklist0(v), _MM_HINT_T0); + #endif + } + } + + counter = 0; + } + } + + // process leftovers + for (size_t id4 = 0; id4 < counter; id4++) { + // evaluate a single distance + const float dis = qdis(saved_indices[id4]); + + // record a traversed edge + graph_visitor.visit_edge(level, node_id, saved_indices[id4], dis); + + // add a record of visited + knowhere::Neighbor nn(saved_indices[id4], dis, saved_statuses[id4]); + if (func_add_candidate(nn)) { +#if defined(USE_PREFETCH) + // TODO + // _mm_prefetch(get_linklist0(v), _MM_HINT_T0); +#endif + } + } + + // update stats + if (track_hnsw_stats) { + stats.ndis = ndis; + stats.nhops = 1; + } + + // done + return stats; + } + + // perform the search on a given level. + // it is assumed that retset is initialized and contains the initial nodes. + faiss::HNSWStats search_on_a_level( + knowhere::NeighborSetDoublePopList& retset, + const int level, + knowhere::IteratorMinHeap* const __restrict disqualified = nullptr, + const float initial_accumulated_alpha = 1.0f + ) { + faiss::HNSWStats stats; + + // + float accumulated_alpha = initial_accumulated_alpha; + + // what to do with a accepted candidate + auto add_search_candidate = [&](const knowhere::Neighbor n) { + return retset.insert(n, disqualified); + }; + + // iterate while possible + while (retset.has_next()) { + // get a node to be processed + const knowhere::Neighbor neighbor = retset.pop(); + + // analyze its neighbors + faiss::HNSWStats local_stats = evaluate_single_node( + neighbor.id, level, accumulated_alpha, add_search_candidate + ); + + // update stats + if (track_hnsw_stats) { + stats.combine(local_stats); + } + } + + // done + return stats; + } + + // perform the search. + faiss::HNSWStats search( + const idx_t k, + float* __restrict distances, + idx_t* __restrict labels + ) { + faiss::HNSWStats stats; + + // is the graph empty? + if (hnsw.entry_point == -1) { + return stats; + } + + // grab some needed parameters + const int efSearch = params ? params->efSearch : hnsw.efSearch; + + // greedy search on upper levels? + if (hnsw.upper_beam != 1) { + FAISS_THROW_MSG("Not implemented"); + return {}; + } + + // yes. + // greedy search on upper levels. + + // todo: first, check LRU cache + + // initialize the starting point. 
+ storage_idx_t nearest = hnsw.entry_point; + float d_nearest = qdis(nearest); + + // iterate through upper levels + for (int level = hnsw.max_level; level >= 1; level--) { + // update the visitor + graph_visitor.visit_level(level); + + // alter the value of 'nearest' + faiss::HNSWStats local_stats = greedy_update_nearest( + level, + nearest, + d_nearest + ); + + // update stats + if (track_hnsw_stats) { + stats.combine(local_stats); + } + } + + // level 0 search + + // update the visitor + graph_visitor.visit_level(0); + + // initialize the container for candidates + const idx_t n_candidates = std::max((idx_t)efSearch, k); + knowhere::NeighborSetDoublePopList retset(n_candidates); + + // initialize retset with a single 'nearest' point + { + if (!filter.allowed(nearest)) { + retset.insert(knowhere::Neighbor(nearest, d_nearest, knowhere::Neighbor::kInvalid)); + } else { + retset.insert(knowhere::Neighbor(nearest, d_nearest, knowhere::Neighbor::kValid)); + } + + visited_nodes[nearest] = true; + } + + // perform the search of the level 0. + faiss::HNSWStats local_stats = search_on_a_level(retset, 0); + + // populate the result + const idx_t len = std::min((idx_t)retset.size(), k); + for (idx_t i = 0; i < len; i++) { + distances[i] = retset[i].distance; + labels[i] = (idx_t)retset[i].id; + } + + // update stats + if (track_hnsw_stats) { + stats.combine(local_stats); + } + + // done + return stats; + } +}; + +} +} +} diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Neighbor.h b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Neighbor.h new file mode 100644 index 000000000..56e351c0a --- /dev/null +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Neighbor.h @@ -0,0 +1,229 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
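The header below introduces the Neighbor record and the candidate queues used by the level-0 search above. A small behavioral sketch follows (an editorial illustration, not part of the patch; ids and distances are made up): valid and filtered-out ("invalid") candidates are kept in two lists, pop() interleaves them in ascending distance order, and only the valid entries are exposed through operator[] for the final result.

// Editorial sketch only; assumes this patch's Neighbor.h is on the include path.
#include <faiss/cppcontrib/knowhere/impl/Neighbor.h>

void example_candidate_queue() {
    using faiss::cppcontrib::knowhere::Neighbor;
    using faiss::cppcontrib::knowhere::NeighborSetDoublePopList;

    NeighborSetDoublePopList retset(3);
    retset.insert(Neighbor(10, 0.5f, Neighbor::kValid));
    retset.insert(Neighbor(11, 0.2f, Neighbor::kInvalid));  // filtered-out point
    retset.insert(Neighbor(12, 0.9f, Neighbor::kValid));

    // pops 11 (0.2), then 10 (0.5), then 12 (0.9)
    while (retset.has_next()) {
        Neighbor n = retset.pop();
        (void)n;
    }

    // the result list is built from the valid entries only:
    // retset[0] is {10, 0.5}, retset[1] is {12, 0.9}
}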
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +struct Neighbor { + static constexpr int kChecked = 0; + static constexpr int kValid = 1; + static constexpr int kInvalid = 2; + + unsigned id; + float distance; + int status; + + Neighbor() = default; + Neighbor(unsigned id, float distance, int status) : id{id}, distance{distance}, status(status) {} + + inline bool + operator<(const Neighbor& other) const { + return distance < other.distance; + } + + inline bool + operator>(const Neighbor& other) const { + return distance > other.distance; + } +}; + +using IteratorMinHeap = std::priority_queue, std::greater>; + +template +class NeighborSetPopList { + private: + inline void + insert_helper(const Neighbor& nbr, size_t pos) { + // move + std::memmove(&data_[pos + 1], &data_[pos], (size_ - pos) * sizeof(Neighbor)); + if (size_ < capacity_) { + size_++; + } + + // insert + data_[pos] = nbr; + } + + public: + explicit NeighborSetPopList(size_t capacity) : capacity_(capacity), data_(capacity + 1) {} + + inline bool + insert(const Neighbor nbr, IteratorMinHeap* disqualified = nullptr) { + auto pos = std::upper_bound(&data_[0], &data_[0] + size_, nbr) - &data_[0]; + if (pos >= capacity_) { + if (disqualified) { + disqualified->push(nbr); + } + return false; + } + if (size_ == capacity_ && disqualified) { + disqualified->push(data_[size_ - 1]); + } + insert_helper(nbr, pos); + if constexpr (need_save) { + if (pos < cur_) { + cur_ = pos; + } + } + return true; + } + + inline auto + pop() -> Neighbor { + auto ret = data_[cur_]; + if constexpr (need_save) { + data_[cur_].status = Neighbor::kChecked; + cur_++; + while (cur_ < size_ && data_[cur_].status == Neighbor::kChecked) { + cur_++; + } + } else { + if (size_ > 1) { + std::memmove(&data_[0], &data_[1], (size_ - 1) * sizeof(Neighbor)); + } + size_--; + } + return ret; + } + + inline auto + has_next() const -> bool { + if constexpr (need_save) { + return cur_ < size_; + } else { + return size_ > 0; + } + } + + inline auto + size() const -> size_t { + return size_; + } + + inline auto + cur() const -> const Neighbor& { + if constexpr (need_save) { + return data_[cur_]; + } else { + return data_[0]; + } + } + + inline auto + at_search_back_dist() const -> float { + if (size_ < capacity_) { + return std::numeric_limits::max(); + } + return data_[capacity_ - 1].distance; + } + + void + clear() { + size_ = 0; + cur_ = 0; + } + + inline const Neighbor& + operator[](size_t i) { + return data_[i]; + } + + private: + size_t capacity_ = 0, size_ = 0, cur_ = 0; + std::vector data_; +}; + +class NeighborSetDoublePopList { + public: + explicit NeighborSetDoublePopList(size_t capacity = 0) { + valid_ns_ = std::make_unique>(capacity); + invalid_ns_ = std::make_unique>(capacity); + } + + // will push any neighbor that does not fit into NeighborSet to disqualified. + // When searching for iterator, those points removed from NeighborSet may be + // qualified candidates as the iterator iterates, thus we need to retain + // instead of disposing them. 
+ bool + insert(const Neighbor& nbr, IteratorMinHeap* disqualified = nullptr) { + if (nbr.status == Neighbor::kValid) { + return valid_ns_->insert(nbr, disqualified); + } else { + if (nbr.distance < valid_ns_->at_search_back_dist()) { + return invalid_ns_->insert(nbr, disqualified); + } else if (disqualified) { + disqualified->push(nbr); + } + } + return false; + } + auto + pop() -> Neighbor { + return pop_based_on_distance(); + } + + auto + has_next() const -> bool { + return valid_ns_->has_next() || + (invalid_ns_->has_next() && invalid_ns_->cur().distance < valid_ns_->at_search_back_dist()); + } + + inline const Neighbor& + operator[](size_t i) { + return (*valid_ns_)[i]; + } + + inline size_t + size() const { + return valid_ns_->size(); + } + + private: + auto + pop_based_on_distance() -> Neighbor { + bool hasCandNext = invalid_ns_->has_next(); + bool hasResNext = valid_ns_->has_next(); + + if (hasCandNext && hasResNext) { + return invalid_ns_->cur().distance < valid_ns_->cur().distance ? invalid_ns_->pop() : valid_ns_->pop(); + } + if (hasCandNext != hasResNext) { + return hasCandNext ? invalid_ns_->pop() : valid_ns_->pop(); + } + return {0, 0, Neighbor::kValid}; + } + + std::unique_ptr> valid_ns_ = nullptr; + std::unique_ptr> invalid_ns_ = nullptr; +}; + +static inline int +InsertIntoPool(Neighbor* addr, intptr_t size, Neighbor nn) { + intptr_t p = std::lower_bound(addr, addr + size, nn) - addr; + std::memmove(addr + p + 1, addr + p, (size - p) * sizeof(Neighbor)); + addr[p] = nn; + return p; +} + +} +} +} diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/utils/Bitset.h b/thirdparty/faiss/faiss/cppcontrib/knowhere/utils/Bitset.h new file mode 100644 index 000000000..c739dbf89 --- /dev/null +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/utils/Bitset.h @@ -0,0 +1,115 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
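The bitset defined below is what IndexHNSWWrapper uses as its per-query visited-nodes table. A short usage sketch (an editorial illustration, not part of the patch; ntotal is a placeholder):

// Editorial sketch only; assumes this patch's Bitset.h is on the include path.
#include <faiss/cppcontrib/knowhere/utils/Bitset.h>

void example_bitset() {
    using faiss::cppcontrib::knowhere::Bitset;

    const size_t ntotal = 1000;

    // allocated but not zeroed; clear() must be called before the first use
    Bitset visited = Bitset::create_uninitialized(ntotal);
    visited.clear();

    if (!visited.get(42)) {
        visited.set(42);     // mark node 42 as visited
    }

    visited[7] = true;       // proxy-based access is also available
    const bool seen = visited[7];
    (void)seen;
}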
+ +#pragma once + +#include +#include +#include +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +struct Bitset final { + struct Proxy { + uint8_t& element; + uint8_t mask; + + inline Proxy(uint8_t& _element, const size_t _shift) : + element{_element}, mask(uint8_t(1) << _shift) {} + + inline operator bool() const { return ((element & mask) != 0); } + + inline Proxy& operator=(const bool value) { + if (value) { set(); } else { reset(); } + return *this; + } + + inline void set() { + element |= mask; + } + + inline void reset() { + element &= ~mask; + } + }; + + inline Bitset() {} + + // create an uncleared bitset + inline static Bitset create_uninitialized(const size_t initial_size) { + Bitset bitset; + + const size_t nbytes = (initial_size + 7) / 8; + + bitset.bits = std::make_unique(nbytes); + bitset.size = initial_size; + + return bitset; + } + + // create an initialized bitset + inline static Bitset create_cleared(const size_t initial_size) { + Bitset bitset = create_uninitialized(initial_size); + bitset.clear(); + + return bitset; + } + + Bitset(const Bitset&) = delete; + Bitset(Bitset&&) = default; + Bitset& operator=(const Bitset&) = delete; + Bitset& operator=(Bitset&&) = default; + + inline bool get(const size_t index) const { + return (bits[index >> 3] & (0x1 << (index & 0x7))); + } + + inline void set(const size_t index) { + bits[index >> 3] |= uint8_t(0x1 << (index & 0x7)); + } + + inline void reset(const size_t index) { + bits[index >> 3] &= (~uint8_t(0x1 << (index & 0x7))); + } + + inline const uint8_t* get_ptr(const size_t index) const { + return bits.get() + index / 8; + } + + inline uint8_t* get_ptr(const size_t index) { + return bits.get() + index / 8; + } + + inline void clear() { + const size_t nbytes = (size + 7) / 8; + std::memset(bits.get(), 0, nbytes); + } + + inline Proxy operator[](const size_t bit_idx) { + uint8_t& element = bits[bit_idx / 8]; + const size_t shift = bit_idx & 7; + return Proxy{element, shift}; + } + + inline bool operator[](const size_t bit_idx) const { + return get(bit_idx); + } + + std::unique_ptr bits; + size_t size = 0; +}; + +} +} +} diff --git a/thirdparty/faiss/faiss/impl/index_read.cpp b/thirdparty/faiss/faiss/impl/index_read.cpp index 8fe5ad8e4..64c2b198e 100644 --- a/thirdparty/faiss/faiss/impl/index_read.cpp +++ b/thirdparty/faiss/faiss/impl/index_read.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -685,6 +686,20 @@ Index* read_index(IOReader* f, int io_flags) { if (h == fourcc("null")) { // denotes a missing index, useful for some cases return nullptr; + } else if (h == fourcc("IxF9")) { + IndexFlatCosine* idxf = new IndexFlatCosine(); + read_index_header(idxf, f); + idxf->code_size = idxf->d * sizeof(float); + READXBVECTOR(idxf->codes); + READVECTOR(idxf->code_norms); + + // reconstruct inverse norms + idxf->inverse_norms_storage = L2NormsStorage::from_l2_norms(idxf->code_norms); + + FAISS_THROW_IF_NOT( + idxf->codes.size() == idxf->ntotal * idxf->code_size); + // leak! 
+ idx = idxf; } else if ( h == fourcc("IxFI") || h == fourcc("IxF2") || h == fourcc("IxFl")) { IndexFlat* idxf; @@ -1139,7 +1154,7 @@ Index* read_index(IOReader* f, int io_flags) { idx = idxp; } else if ( h == fourcc("IHNf") || h == fourcc("IHNp") || h == fourcc("IHNs") || - h == fourcc("IHN2") || h == fourcc("IHNc")) { + h == fourcc("IHN2") || h == fourcc("IHNc") || h == fourcc("IHN9")) { IndexHNSW* idxhnsw = nullptr; if (h == fourcc("IHNf")) idxhnsw = new IndexHNSWFlat(); @@ -1151,6 +1166,8 @@ Index* read_index(IOReader* f, int io_flags) { idxhnsw = new IndexHNSW2Level(); if (h == fourcc("IHNc")) idxhnsw = new IndexHNSWCagra(); + if (h == fourcc("IHN9")) + idxhnsw = new IndexHNSWFlatCosine(); read_index_header(idxhnsw, f); if (h == fourcc("IHNc")) { READ1(idxhnsw->keep_max_size_level0); diff --git a/thirdparty/faiss/faiss/impl/index_write.cpp b/thirdparty/faiss/faiss/impl/index_write.cpp index d57c6edbf..99ea01d0a 100644 --- a/thirdparty/faiss/faiss/impl/index_write.cpp +++ b/thirdparty/faiss/faiss/impl/index_write.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -559,6 +560,14 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { // eg. for a storage component of HNSW that is set to nullptr uint32_t h = fourcc("null"); WRITE1(h); + } else if (const IndexFlatCosine* idxf = dynamic_cast(idx)) { + uint32_t h = fourcc("IxF9"); + WRITE1(h); + write_index_header(idx, f); + WRITEXBVECTOR(idxf->codes); + // we're storing real l2 norms, because of + // backward compatibility issues. + WRITEVECTOR(idxf->inverse_norms_storage.as_l2_norms()); } else if (const IndexFlat* idxf = dynamic_cast(idx)) { uint32_t h = fourcc(idxf->metric_type == METRIC_INNER_PRODUCT ? "IxFI" @@ -948,6 +957,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { : dynamic_cast(idx) ? fourcc("IHNs") : dynamic_cast(idx) ? fourcc("IHN2") : dynamic_cast(idx) ? fourcc("IHNc") + : dynamic_cast(idx) ? 
fourcc("IHN9") : 0; FAISS_THROW_IF_NOT(h != 0); WRITE1(h); From d3fa0dfd2f21e2c4c0263f09b6d9b2f843569e9e Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Thu, 1 Aug 2024 16:08:52 -0400 Subject: [PATCH 02/21] FAISS_HNSW_SQ Signed-off-by: Alexandr Guzhva --- include/knowhere/comp/index_param.h | 1 + include/knowhere/index/index_table.h | 2 + src/index/hnsw/faiss_hnsw.cc | 159 +++++++++++++++++++- src/index/hnsw/faiss_hnsw_config.h | 91 ++++++++++- thirdparty/faiss/faiss/IndexCosine.cpp | 61 ++++++++ thirdparty/faiss/faiss/IndexCosine.h | 26 ++++ thirdparty/faiss/faiss/impl/index_read.cpp | 16 +- thirdparty/faiss/faiss/impl/index_write.cpp | 23 ++- 8 files changed, 370 insertions(+), 9 deletions(-) diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index 612a6feb9..2f71ac84a 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -54,6 +54,7 @@ constexpr const char* INDEX_HNSW_SQ8_REFINE = "HNSW_SQ8_REFINE"; constexpr const char* INDEX_DISKANN = "DISKANN"; constexpr const char* INDEX_FAISS_HNSW_FLAT = "FAISS_HNSW_FLAT"; +constexpr const char* INDEX_FAISS_HNSW_SQ = "FAISS_HNSW_SQ"; constexpr const char* INDEX_SPARSE_INVERTED_INDEX = "SPARSE_INVERTED_INDEX"; constexpr const char* INDEX_SPARSE_WAND = "SPARSE_WAND"; diff --git a/include/knowhere/index/index_table.h b/include/knowhere/index/index_table.h index 3c29973ba..464804929 100644 --- a/include/knowhere/index/index_table.h +++ b/include/knowhere/index/index_table.h @@ -68,6 +68,8 @@ static std::set> legal_knowhere_index = { {IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_BFLOAT16}, // faiss hnsw {IndexEnum::INDEX_FAISS_HNSW_FLAT, VecType::VECTOR_FLOAT}, + + {IndexEnum::INDEX_FAISS_HNSW_SQ, VecType::VECTOR_FLOAT}, // diskann {IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT}, {IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT16}, diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index c42269ff7..07fe85c9f 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -423,7 +423,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { // yes, wrap both base and refine index knowhere::IndexWrapperCosine cosine_wrapper( index_refine->refine_index, - dynamic_cast(index_hnsw)->get_inverse_l2_norms()); + dynamic_cast(index_hnsw->storage)->get_inverse_l2_norms()); // create a temporary refine index which does not own faiss::IndexRefine tmp_refine(base_wrapper, &cosine_wrapper); @@ -625,6 +625,163 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode { } }; +// +template +class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { + public: + BaseFaissRegularIndexHNSWSQNode(const int32_t& version, const Object& object) : + BaseFaissRegularIndexHNSWNode(version, object) {} + + std::unique_ptr + CreateConfig() const override { + return std::make_unique(); + } + + std::string + Type() const override { + return knowhere::IndexEnum::INDEX_FAISS_HNSW_SQ; + } + + protected: + Status TrainInternal(const DataSetPtr dataset, const Config& cfg) override { + // number of rows + auto rows = dataset->GetRows(); + // dimensionality of the data + auto dim = dataset->GetDim(); + // data + auto data = dataset->GetTensor(); + + // config + auto hnsw_cfg = static_cast(cfg); + + auto metric = Str2FaissMetricType(hnsw_cfg.metric_type.value()); + if (!metric.has_value()) { + LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << hnsw_cfg.metric_type.value(); + return 
Status::invalid_metric_type; + } + + // parse a ScalarQuantizer type + auto sq_type = get_sq_quantizer_type(hnsw_cfg.sq_type.value()); + if (!sq_type.has_value()) { + LOG_KNOWHERE_ERROR_ << "Invalid scalar quantizer type: " << hnsw_cfg.sq_type.value(); + return Status::invalid_args; + } + + // create an index + const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE); + + std::unique_ptr hnsw_index; + if (is_cosine) { + hnsw_index = std::make_unique( + dim, sq_type.value(), hnsw_cfg.M.value()); + } else { + hnsw_index = std::make_unique( + dim, sq_type.value(), hnsw_cfg.M.value(), metric.value()); + } + + hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); + + // should refine be used? + std::unique_ptr final_index; + if (hnsw_cfg.refine.value_or(false)) { + // yes + + // grab a type of a refine index + bool is_fp32_flat = false; + if (!hnsw_cfg.refine_type.has_value()) { + is_fp32_flat = true; + } else { + // todo: tolower + if (hnsw_cfg.refine_type.value() == "FP32" || hnsw_cfg.refine_type.value() == "FLAT") { + is_fp32_flat = true; + } else { + // parse + auto refine_sq_type = get_sq_quantizer_type(hnsw_cfg.refine_type.value()); + if (!refine_sq_type.has_value()) { + LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << hnsw_cfg.refine_type.value(); + return Status::invalid_args; + } + + // recognized one, it is not fp32 flat + is_fp32_flat = false; + } + } + + // either build flat or sq + if (is_fp32_flat) { + // build IndexFlat as a refine + auto refine_index = std::make_unique(hnsw_index.get()); + + // let refine_index to own everything + refine_index->own_fields = true; + hnsw_index.release(); + + // reassign + final_index = std::move(refine_index); + } else { + // being IndexScalarQuantizer as a refine + auto refine_sq_type = get_sq_quantizer_type(hnsw_cfg.refine_type.value()); + + // a redundant check + if (!refine_sq_type.has_value()) { + LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << hnsw_cfg.refine_type.value(); + return Status::invalid_args; + } + + // create an sq + auto sq_refine = std::make_unique( + dim, refine_sq_type.value(), metric.value() + ); + + auto refine_index = std::make_unique(hnsw_index.get(), sq_refine.get()); + + // let refine_index to own everything + refine_index->own_refine_index = true; + refine_index->own_fields = true; + hnsw_index.release(); + sq_refine.release(); + + // reassign + final_index = std::move(refine_index); + } + } else { + // no refine + + // reassign + final_index = std::move(hnsw_index); + } + + // train + final_index->train(rows, (const float*)data); + + // done + index = std::move(final_index); + return Status::success; + } + +private: + expected + static get_sq_quantizer_type(const std::string& sq_type) { + std::map sq_types = { + {"SQ6", faiss::ScalarQuantizer::QT_6bit}, + {"SQ8", faiss::ScalarQuantizer::QT_8bit}, + {"FP16", faiss::ScalarQuantizer::QT_fp16}, + {"BF16", faiss::ScalarQuantizer::QT_bf16} + }; + + // todo: tolower + auto itr = sq_types.find(sq_type); + if (itr == sq_types.cend()) { + return expected::Err( + Status::invalid_args, fmt::format("invalid scalar quantizer type ({})", sq_type)); + } + + return itr->second; + } +}; + + KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNode, fp32); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_SQ, BaseFaissRegularIndexHNSWSQNode, fp32); } // namespace knowhere diff --git a/src/index/hnsw/faiss_hnsw_config.h b/src/index/hnsw/faiss_hnsw_config.h index f36450460..637e6280d 100644 --- 
a/src/index/hnsw/faiss_hnsw_config.h +++ b/src/index/hnsw/faiss_hnsw_config.h @@ -39,6 +39,8 @@ class FaissHnswConfig : public BaseConfig { CFG_BOOL refine; // undefined value leads to a search without a refine CFG_FLOAT refine_k; + // type of refine + CFG_STRING refine_type; KNOHWERE_DECLARE_CONFIG(FaissHnswConfig) { KNOWHERE_CONFIG_DECLARE_FIELD(M).description("hnsw M").set_default(30).set_range(2, 2048).for_train(); @@ -78,6 +80,10 @@ class FaissHnswConfig : public BaseConfig { .allow_empty_without_default() .set_range(1, std::numeric_limits::max()) .for_search(); + KNOWHERE_CONFIG_DECLARE_FIELD(refine_type) + .description("the type of a refine index") + .allow_empty_without_default() + .for_train(); } Status @@ -107,6 +113,23 @@ class FaissHnswConfig : public BaseConfig { } return Status::success; } + + protected: + bool WhetherAcceptableRefineType(const std::string& refine_type) { + // 'flat' is identical to 'fp32' + std::vector allowed_list = { + "SQ8", "FP16", "BF16", "FP32", "FLAT"}; + + // todo: tolower() + + for (const auto& allowed : allowed_list) { + if (refine_type == allowed) { + return true; + } + } + + return false; + } }; class FaissHnswFlatConfig : public FaissHnswConfig { @@ -121,14 +144,80 @@ class FaissHnswFlatConfig : public FaissHnswConfig { // check our parameters if (param_type == PARAM_TYPE::TRAIN) { + // prohibit refine if (refine.value_or(false)) { - *err_msg = "refine is not currently supported for this index"; + *err_msg = "refine is not supported for this index"; + LOG_KNOWHERE_ERROR_ << *err_msg; + return Status::invalid_value_in_json; + } + + if (refine_type.has_value()) { + *err_msg = "refine is not supported for this index"; + LOG_KNOWHERE_ERROR_ << *err_msg; + return Status::invalid_value_in_json; + } + } + return Status::success; + } +}; + +class FaissHnswSqConfig : public FaissHnswConfig { +public: + // user can use quant_type to control quantizer type. 
+ // we have fp16, bf16, etc, so '8', '4' and '6' is insufficient + CFG_STRING sq_type; + KNOHWERE_DECLARE_CONFIG(FaissHnswSqConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(sq_type) + .set_default("SQ8") + .description("scalar quantizer type") + .for_train(); + }; + + Status + CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { + // check the base class + const auto base_status = FaissHnswConfig::CheckAndAdjust(param_type, err_msg); + if (base_status != Status::success) { + return base_status; + } + + // check our parameters + if (param_type == PARAM_TYPE::TRAIN) { + auto sq_type_v = sq_type.value(); + if (!WhetherAcceptableQuantType(sq_type_v)) { + *err_msg = "invalid scalar quantizer type"; LOG_KNOWHERE_ERROR_ << *err_msg; return Status::invalid_value_in_json; } + + // check refine + if (refine_type.has_value()) { + if (!WhetherAcceptableRefineType(refine_type.value())) { + *err_msg = "invalid refine type type"; + LOG_KNOWHERE_ERROR_ << *err_msg; + return Status::invalid_value_in_json; + } + } } return Status::success; } + +private: + bool WhetherAcceptableQuantType(const std::string& sq_type) { + // todo: add more + std::vector allowed_list = { + "SQ6", "SQ8", "FP16", "BF16"}; + + // todo: tolower() + + for (const auto& allowed : allowed_list) { + if (sq_type == allowed) { + return true; + } + } + + return false; + } }; } // namespace knowhere diff --git a/thirdparty/faiss/faiss/IndexCosine.cpp b/thirdparty/faiss/faiss/IndexCosine.cpp index df4d912cc..7454b4974 100644 --- a/thirdparty/faiss/faiss/IndexCosine.cpp +++ b/thirdparty/faiss/faiss/IndexCosine.cpp @@ -300,6 +300,52 @@ FlatCodesDistanceComputer* IndexFlatCosine::get_FlatCodesDistanceComputer() cons } +////////////////////////////////////////////////////////////////////////////////// + +IndexScalarQuantizerCosine::IndexScalarQuantizerCosine( + int d, + ScalarQuantizer::QuantizerType qtype) + : IndexScalarQuantizer(d, qtype, MetricType::METRIC_INNER_PRODUCT) { + is_cosine = true; +} + +IndexScalarQuantizerCosine::IndexScalarQuantizerCosine() : IndexScalarQuantizer() { + metric_type = MetricType::METRIC_INNER_PRODUCT; + is_cosine = true; +} + +void IndexScalarQuantizerCosine::add(idx_t n, const float* x) { + FAISS_THROW_IF_NOT(is_trained); + if (n == 0) { + return; + } + + // todo aguzhva: + // it is a tricky situation at this moment, because IndexScalarQuantizerCosine + // contains duplicate norms (one in IndexFlatCodes and one in HasInverseL2Norms). + // Norms in IndexFlatCodes are going to be removed in the future. 
+ IndexScalarQuantizer::add(n, x); + inverse_norms_storage.add(x, n, d); +} + +void IndexScalarQuantizerCosine::reset() { + IndexScalarQuantizer::reset(); + inverse_norms_storage.reset(); +} + +const float* IndexScalarQuantizerCosine::get_inverse_l2_norms() const { + return inverse_norms_storage.inverse_l2_norms.data(); +} + +DistanceComputer* IndexScalarQuantizerCosine::get_distance_computer() const { + return new WithCosineNormDistanceComputer( + this->get_inverse_l2_norms(), + this->d, + std::unique_ptr(IndexScalarQuantizer::get_FlatCodesDistanceComputer()) + ); +} + + ////////////////////////////////////////////////////////////////////////////////// // @@ -314,6 +360,21 @@ IndexHNSWFlatCosine::IndexHNSWFlatCosine(int d, int M) : is_trained = true; } +////////////////////////////////////////////////////////////////////////////////// + +// +IndexHNSWSQCosine::IndexHNSWSQCosine() = default; + +IndexHNSWSQCosine::IndexHNSWSQCosine( + int d, + ScalarQuantizer::QuantizerType qtype, + int M) : + IndexHNSW(new IndexScalarQuantizerCosine(d, qtype), M) +{ + is_trained = this->storage->is_trained; + own_fields = true; +} + } diff --git a/thirdparty/faiss/faiss/IndexCosine.h b/thirdparty/faiss/faiss/IndexCosine.h index 0b952fd3d..cd283f795 100644 --- a/thirdparty/faiss/faiss/IndexCosine.h +++ b/thirdparty/faiss/faiss/IndexCosine.h @@ -99,12 +99,38 @@ struct IndexFlatCosine : IndexFlat, HasInverseL2Norms { const float* get_inverse_l2_norms() const override; }; +// +struct IndexScalarQuantizerCosine : IndexScalarQuantizer, HasInverseL2Norms { + L2NormsStorage inverse_norms_storage; + + IndexScalarQuantizerCosine( + int d, + ScalarQuantizer::QuantizerType qtype); + + IndexScalarQuantizerCosine(); + + void add(idx_t n, const float* x) override; + void reset() override; + + DistanceComputer* get_distance_computer() const override; + + const float* get_inverse_l2_norms() const override; +}; + // struct IndexHNSWFlatCosine : IndexHNSW { IndexHNSWFlatCosine(); IndexHNSWFlatCosine(int d, int M); }; +// +struct IndexHNSWSQCosine : IndexHNSW { + IndexHNSWSQCosine(); + IndexHNSWSQCosine( + int d, + ScalarQuantizer::QuantizerType qtype, + int M); +}; } diff --git a/thirdparty/faiss/faiss/impl/index_read.cpp b/thirdparty/faiss/faiss/impl/index_read.cpp index 64c2b198e..b980f2859 100644 --- a/thirdparty/faiss/faiss/impl/index_read.cpp +++ b/thirdparty/faiss/faiss/impl/index_read.cpp @@ -973,6 +973,17 @@ Index* read_index(IOReader* f, int io_flags) { } read_InvertedLists(ivfl, f, io_flags); idx = ivfl; + } else if (h == fourcc("IxS8")) { + IndexScalarQuantizerCosine* idxs = new IndexScalarQuantizerCosine(); + read_index_header(idxs, f); + read_ScalarQuantizer(&idxs->sq, f); + READVECTOR(idxs->codes); + idxs->code_size = idxs->sq.code_size; + + // reconstruct inverse norms + READVECTOR(idxs->inverse_norms_storage.inverse_l2_norms); + + idx = idxs; } else if (h == fourcc("IxSQ")) { IndexScalarQuantizer* idxs = new IndexScalarQuantizer(); read_index_header(idxs, f); @@ -1154,7 +1165,8 @@ Index* read_index(IOReader* f, int io_flags) { idx = idxp; } else if ( h == fourcc("IHNf") || h == fourcc("IHNp") || h == fourcc("IHNs") || - h == fourcc("IHN2") || h == fourcc("IHNc") || h == fourcc("IHN9")) { + h == fourcc("IHN2") || h == fourcc("IHNc") || h == fourcc("IHN9") || + h == fourcc("IHN8")) { IndexHNSW* idxhnsw = nullptr; if (h == fourcc("IHNf")) idxhnsw = new IndexHNSWFlat(); @@ -1168,6 +1180,8 @@ Index* read_index(IOReader* f, int io_flags) { idxhnsw = new IndexHNSWCagra(); if (h == fourcc("IHN9")) idxhnsw = new 
IndexHNSWFlatCosine(); + if (h == fourcc("IHN8")) + idxhnsw = new IndexHNSWSQCosine(); read_index_header(idxhnsw, f); if (h == fourcc("IHNc")) { READ1(idxhnsw->keep_max_size_level0); diff --git a/thirdparty/faiss/faiss/impl/index_write.cpp b/thirdparty/faiss/faiss/impl/index_write.cpp index 99ea01d0a..b465e4bce 100644 --- a/thirdparty/faiss/faiss/impl/index_write.cpp +++ b/thirdparty/faiss/faiss/impl/index_write.cpp @@ -771,6 +771,16 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { WRITE1(idxp_2->code_size_2); WRITE1(idxp_2->code_size); WRITEVECTOR(idxp_2->codes); + } else if ( + const IndexScalarQuantizerCosine* idxs = + dynamic_cast(idx)) { + uint32_t h = fourcc("IxS8"); + WRITE1(h); + write_index_header(idx, f); + write_ScalarQuantizer(&idxs->sq, f); + WRITEVECTOR(idxs->codes); + // inverse norms + WRITEVECTOR(idxs->inverse_norms_storage.inverse_l2_norms); } else if ( const IndexScalarQuantizer* idxs = dynamic_cast(idx)) { @@ -952,13 +962,14 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { write_index(idxmap->index, f); WRITEVECTOR(idxmap->id_map); } else if (const IndexHNSW* idxhnsw = dynamic_cast(idx)) { - uint32_t h = dynamic_cast(idx) ? fourcc("IHNf") - : dynamic_cast(idx) ? fourcc("IHNp") - : dynamic_cast(idx) ? fourcc("IHNs") - : dynamic_cast(idx) ? fourcc("IHN2") - : dynamic_cast(idx) ? fourcc("IHNc") + uint32_t h = dynamic_cast(idx) ? fourcc("IHNf") + : dynamic_cast(idx) ? fourcc("IHNp") + : dynamic_cast(idx) ? fourcc("IHNs") + : dynamic_cast(idx) ? fourcc("IHN2") + : dynamic_cast(idx) ? fourcc("IHNc") : dynamic_cast(idx) ? fourcc("IHN9") - : 0; + : dynamic_cast(idx) ? fourcc("IHN8") + : 0; FAISS_THROW_IF_NOT(h != 0); WRITE1(h); write_index_header(idxhnsw, f); From d7f69527f4d2be0a2f051cfac69df15314db3ba0 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Thu, 1 Aug 2024 18:39:35 -0400 Subject: [PATCH 03/21] FAISS_HNSW_PQ Signed-off-by: Alexandr Guzhva --- include/knowhere/comp/index_param.h | 4 + src/index/hnsw/faiss_hnsw.cc | 365 ++++++++++++++++---- src/index/hnsw/faiss_hnsw_config.h | 21 ++ thirdparty/faiss/faiss/IndexCosine.cpp | 61 ++++ thirdparty/faiss/faiss/IndexCosine.h | 34 ++ thirdparty/faiss/faiss/impl/index_read.cpp | 21 +- thirdparty/faiss/faiss/impl/index_write.cpp | 13 + 7 files changed, 443 insertions(+), 76 deletions(-) diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index 2f71ac84a..c89a0ab5b 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -55,6 +55,7 @@ constexpr const char* INDEX_DISKANN = "DISKANN"; constexpr const char* INDEX_FAISS_HNSW_FLAT = "FAISS_HNSW_FLAT"; constexpr const char* INDEX_FAISS_HNSW_SQ = "FAISS_HNSW_SQ"; +constexpr const char* INDEX_FAISS_HNSW_PQ = "FAISS_HNSW_PQ"; constexpr const char* INDEX_SPARSE_INVERTED_INDEX = "SPARSE_INVERTED_INDEX"; constexpr const char* INDEX_SPARSE_WAND = "SPARSE_WAND"; @@ -153,6 +154,9 @@ constexpr const char* HNSW_M = "M"; constexpr const char* EF = "ef"; constexpr const char* OVERVIEW_LEVELS = "overview_levels"; +// FAISS additional Params +constexpr const char* SQ_TYPE = "sq_type"; // for IVF_SQ and HNSW_SQ + // Sparse Params constexpr const char* DROP_RATIO_BUILD = "drop_ratio_build"; constexpr const char* DROP_RATIO_SEARCH = "drop_ratio_search"; diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 07fe85c9f..40fdc2f91 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -539,7 +539,7 @@ class BaseFaissRegularIndexHNSWNode : public 
BaseFaissRegularIndexNode { Status AddInternal(const DataSetPtr dataset, const Config&) override { - if (this->index == nullptr) { + if (index == nullptr) { LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; return Status::empty_index; } @@ -547,7 +547,9 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { auto data = dataset->GetTensor(); auto rows = dataset->GetRows(); try { - this->index->add(rows, reinterpret_cast(data)); + LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; + + index->add(rows, reinterpret_cast(data)); } catch (const std::exception& e) { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return Status::faiss_inner_error; @@ -617,6 +619,8 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode { hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); // train + LOG_KNOWHERE_INFO_ << "Training HNSW Index"; + hnsw_index->train(rows, (const float*)data); // done @@ -625,6 +629,112 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode { } }; + +namespace { + +// a supporting function +expected +get_sq_quantizer_type(const std::string& sq_type) { + std::map sq_types = { + {"SQ6", faiss::ScalarQuantizer::QT_6bit}, + {"SQ8", faiss::ScalarQuantizer::QT_8bit}, + {"FP16", faiss::ScalarQuantizer::QT_fp16}, + {"BF16", faiss::ScalarQuantizer::QT_bf16} + }; + + // todo: tolower + auto itr = sq_types.find(sq_type); + if (itr == sq_types.cend()) { + return expected::Err( + Status::invalid_args, fmt::format("invalid scalar quantizer type ({})", sq_type)); + } + + return itr->second; +} + +expected +is_flat_refine(const std::optional& refine_type) { + // grab a type of a refine index + if (!refine_type.has_value()) { + return true; + }; + + // todo: tolower + if (refine_type.value() == "FP32" || refine_type.value() == "FLAT") { + return true; + }; + + // parse + auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); + if (!refine_sq_type.has_value()) { + LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value(); + return expected::Err( + Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value())); + } + + return false; +} + +// pick a refine index +expected> +pick_refine_index(const std::optional& refine_type, std::unique_ptr&& hnsw_index) { + // yes + + // grab a type of a refine index + expected is_fp32_flat = is_flat_refine(refine_type); + if (!is_fp32_flat.has_value()) { + return expected>::Err( + Status::invalid_args, ""); + } + + const bool is_fp32_flat_v = is_fp32_flat.value(); + + std::unique_ptr local_hnsw_index = std::move(hnsw_index); + + // either build flat or sq + if (is_fp32_flat_v) { + // build IndexFlat as a refine + auto refine_index = std::make_unique(local_hnsw_index.get()); + + // let refine_index to own everything + refine_index->own_fields = true; + local_hnsw_index.release(); + + // reassign + return refine_index; + } else { + // being IndexScalarQuantizer as a refine + auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); + + // a redundant check + if (!refine_sq_type.has_value()) { + LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value(); + return expected>::Err( + Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value())); + } + + // create an sq + auto sq_refine = std::make_unique( + local_hnsw_index->storage->d, + refine_sq_type.value(), + local_hnsw_index->storage->metric_type + ); + + auto refine_index = std::make_unique(local_hnsw_index.get(), 
sq_refine.get()); + + // let refine_index to own everything + refine_index->own_refine_index = true; + refine_index->own_fields = true; + local_hnsw_index.release(); + sq_refine.release(); + + // reassign + return refine_index; + } +} + +} // namespace + // template class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { @@ -685,103 +795,208 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { std::unique_ptr final_index; if (hnsw_cfg.refine.value_or(false)) { // yes + auto final_index_cnd = pick_refine_index(hnsw_cfg.refine_type, std::move(hnsw_index)); + if (!final_index_cnd.has_value()) { + return Status::invalid_args; + } - // grab a type of a refine index - bool is_fp32_flat = false; - if (!hnsw_cfg.refine_type.has_value()) { - is_fp32_flat = true; - } else { - // todo: tolower - if (hnsw_cfg.refine_type.value() == "FP32" || hnsw_cfg.refine_type.value() == "FLAT") { - is_fp32_flat = true; - } else { - // parse - auto refine_sq_type = get_sq_quantizer_type(hnsw_cfg.refine_type.value()); - if (!refine_sq_type.has_value()) { - LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << hnsw_cfg.refine_type.value(); - return Status::invalid_args; - } + // assign + final_index = std::move(final_index_cnd.value()); + } else { + // no refine - // recognized one, it is not fp32 flat - is_fp32_flat = false; - } - } + // assign + final_index = std::move(hnsw_index); + } - // either build flat or sq - if (is_fp32_flat) { - // build IndexFlat as a refine - auto refine_index = std::make_unique(hnsw_index.get()); - - // let refine_index to own everything - refine_index->own_fields = true; - hnsw_index.release(); - - // reassign - final_index = std::move(refine_index); - } else { - // being IndexScalarQuantizer as a refine - auto refine_sq_type = get_sq_quantizer_type(hnsw_cfg.refine_type.value()); - - // a redundant check - if (!refine_sq_type.has_value()) { - LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << hnsw_cfg.refine_type.value(); - return Status::invalid_args; - } - - // create an sq - auto sq_refine = std::make_unique( - dim, refine_sq_type.value(), metric.value() - ); - - auto refine_index = std::make_unique(hnsw_index.get(), sq_refine.get()); - - // let refine_index to own everything - refine_index->own_refine_index = true; - refine_index->own_fields = true; - hnsw_index.release(); - sq_refine.release(); - - // reassign - final_index = std::move(refine_index); - } + // train + LOG_KNOWHERE_INFO_ << "Training HNSW Index"; + + final_index->train(rows, (const float*)data); + + // done + index = std::move(final_index); + return Status::success; + } +}; + + +// this index trains PQ and HNSW+FLAT separately, then constructs HNSW+PQ +template +class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { + public: + BaseFaissRegularIndexHNSWPQNode(const int32_t& version, const Object& object) : + BaseFaissRegularIndexHNSWNode(version, object) {} + + std::unique_ptr + CreateConfig() const override { + return std::make_unique(); + } + + std::string + Type() const override { + return knowhere::IndexEnum::INDEX_FAISS_HNSW_PQ; + } + + protected: + std::unique_ptr tmp_index_pq; + + Status TrainInternal(const DataSetPtr dataset, const Config& cfg) override { + // number of rows + auto rows = dataset->GetRows(); + // dimensionality of the data + auto dim = dataset->GetDim(); + // data + auto data = dataset->GetTensor(); + + // config + auto hnsw_cfg = static_cast(cfg); + + auto metric = Str2FaissMetricType(hnsw_cfg.metric_type.value()); + if 
(!metric.has_value()) { + LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << hnsw_cfg.metric_type.value(); + return Status::invalid_metric_type; + } + + // create an index + const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE); + + // HNSW + PQ index yields BAD recall somewhy. + // Let's build HNSW+FLAT index, then replace FLAT with PQ + + std::unique_ptr hnsw_index; + if (is_cosine) { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); + } else { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); + } + + hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); + + // pq + std::unique_ptr pq_index; + if (is_cosine) { + pq_index = std::make_unique(dim, hnsw_cfg.m.value(), hnsw_cfg.nbits.value()); + } else { + pq_index = std::make_unique(dim, hnsw_cfg.m.value(), hnsw_cfg.nbits.value(), metric.value()); + } + + + // should refine be used? + std::unique_ptr final_index; + if (hnsw_cfg.refine.value_or(false)) { + // yes + auto final_index_cnd = pick_refine_index(hnsw_cfg.refine_type, std::move(hnsw_index)); + if (!final_index_cnd.has_value()) { + return Status::invalid_args; + } + + // assign + final_index = std::move(final_index_cnd.value()); } else { // no refine - // reassign + // assign final_index = std::move(hnsw_index); } - // train + // train hnswflat + LOG_KNOWHERE_INFO_ << "Training HNSW Index"; + final_index->train(rows, (const float*)data); + // train pq + LOG_KNOWHERE_INFO_ << "Training PQ Index"; + + pq_index->train(rows, (const float*)data); + pq_index->pq.compute_sdc_table(); + // done index = std::move(final_index); + tmp_index_pq = std::move(pq_index); return Status::success; } -private: - expected - static get_sq_quantizer_type(const std::string& sq_type) { - std::map sq_types = { - {"SQ6", faiss::ScalarQuantizer::QT_6bit}, - {"SQ8", faiss::ScalarQuantizer::QT_8bit}, - {"FP16", faiss::ScalarQuantizer::QT_fp16}, - {"BF16", faiss::ScalarQuantizer::QT_bf16} - }; + Status + AddInternal(const DataSetPtr dataset, const Config&) override { + if (this->index == nullptr) { + LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; + return Status::empty_index; + } + + auto data = dataset->GetTensor(); + auto rows = dataset->GetRows(); + try { + // hnsw + LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; + + index->add(rows, reinterpret_cast(data)); + + // pq + LOG_KNOWHERE_INFO_ << "Adding " << rows << " to PQ Index"; + + tmp_index_pq->add(rows, reinterpret_cast(data)); - // todo: tolower - auto itr = sq_types.find(sq_type); - if (itr == sq_types.cend()) { - return expected::Err( - Status::invalid_args, fmt::format("invalid scalar quantizer type ({})", sq_type)); + // we're done. + // throw away flat and replace it with pq + + // check if we have a refine available. 
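+            // What happens below:
+            //   1. locate the inner faiss::IndexHNSW (either directly, or behind an
+            //      IndexRefine if a refine index was requested);
+            //   2. move its IndexHNSW part (the built graph) into a freshly created
+            //      HNSW-over-PQ index via the base-class assignment;
+            //   3. delete the temporary flat storage and install the separately
+            //      trained and populated PQ index as the new storage.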
+ faiss::IndexHNSW* index_hnsw = nullptr; + + faiss::IndexRefine* const index_refine = + dynamic_cast(index.get()); + + if (index_refine != nullptr) { + index_hnsw = dynamic_cast(index_refine->base_index); + } else { + index_hnsw = dynamic_cast(index.get()); + } + + // recreate hnswpq + std::unique_ptr index_hnsw_pq; + + if (index_hnsw->storage->is_cosine) { + index_hnsw_pq = std::make_unique(); + } else { + index_hnsw_pq = std::make_unique(); + } + + // C++ slicing + static_cast(*index_hnsw_pq) = + std::move(static_cast(*index_hnsw)); + + // clear out the storage + delete index_hnsw->storage; + index_hnsw->storage = nullptr; + index_hnsw_pq->storage = nullptr; + + // replace storage + index_hnsw_pq->storage = tmp_index_pq.release(); + + // replace if refine + if (index_refine != nullptr) { + delete index_refine->base_index; + index_refine->base_index = index_hnsw_pq.release(); + } else { + index = std::move(index_hnsw_pq); + } + + } catch (const std::exception& e) { + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); + return Status::faiss_inner_error; } - return itr->second; + return Status::success; } + }; + KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNode, fp32); + KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_SQ, BaseFaissRegularIndexHNSWSQNode, fp32); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PQ, BaseFaissRegularIndexHNSWPQNode, fp32); + } // namespace knowhere diff --git a/src/index/hnsw/faiss_hnsw_config.h b/src/index/hnsw/faiss_hnsw_config.h index 637e6280d..c1fcb76f6 100644 --- a/src/index/hnsw/faiss_hnsw_config.h +++ b/src/index/hnsw/faiss_hnsw_config.h @@ -220,6 +220,27 @@ class FaissHnswSqConfig : public FaissHnswConfig { } }; +class FaissHnswPqConfig : public FaissHnswConfig { +public: + // number of subquantizers + CFG_INT m; + // number of bits per subquantizer + CFG_INT nbits; + + KNOHWERE_DECLARE_CONFIG(FaissHnswPqConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(m) + .description("m") + .set_default(32) + .for_train() + .set_range(1, 65536); + KNOWHERE_CONFIG_DECLARE_FIELD(nbits) + .description("nbits") + .set_default(8) + .for_train() + .set_range(1, 16); + } +}; + } // namespace knowhere #endif /* FAISS_HNSW_CONFIG_H */ diff --git a/thirdparty/faiss/faiss/IndexCosine.cpp b/thirdparty/faiss/faiss/IndexCosine.cpp index 7454b4974..fc12c61bb 100644 --- a/thirdparty/faiss/faiss/IndexCosine.cpp +++ b/thirdparty/faiss/faiss/IndexCosine.cpp @@ -346,6 +346,51 @@ DistanceComputer* IndexScalarQuantizerCosine::get_distance_computer() const { } +////////////////////////////////////////////////////////////////////////////////// + +// +IndexPQCosine::IndexPQCosine(int d, size_t M, size_t nbits) : + IndexPQ(d, M, nbits, MetricType::METRIC_INNER_PRODUCT) { + is_cosine = true; +} + +IndexPQCosine::IndexPQCosine() : IndexPQ() { + metric_type = MetricType::METRIC_INNER_PRODUCT; + is_cosine = true; +} + +void IndexPQCosine::add(idx_t n, const float* x) { + FAISS_THROW_IF_NOT(is_trained); + if (n == 0) { + return; + } + + // todo aguzhva: + // it is a tricky situation at this moment, because IndexPQCosine + // contains duplicate norms (one in IndexFlatCodes and one in HasInverseL2Norms). + // Norms in IndexFlatCodes are going to be removed in the future. 
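+    // Keep the inverse L2 norms of the newly added vectors next to the PQ codes;
+    // get_distance_computer() passes them to WithCosineNormDistanceComputer, which
+    // rescales the raw inner-product scores into cosine similarity.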
+ IndexPQ::add(n, x); + inverse_norms_storage.add(x, n, d); +} + +void IndexPQCosine::reset() { + IndexPQ::reset(); + inverse_norms_storage.reset(); +} + +const float* IndexPQCosine::get_inverse_l2_norms() const { + return inverse_norms_storage.inverse_l2_norms.data(); +} + +DistanceComputer* IndexPQCosine::get_distance_computer() const { + return new WithCosineNormDistanceComputer( + this->get_inverse_l2_norms(), + this->d, + std::unique_ptr(IndexPQ::get_FlatCodesDistanceComputer()) + ); +} + + ////////////////////////////////////////////////////////////////////////////////// // @@ -360,6 +405,7 @@ IndexHNSWFlatCosine::IndexHNSWFlatCosine(int d, int M) : is_trained = true; } + ////////////////////////////////////////////////////////////////////////////////// // @@ -376,5 +422,20 @@ IndexHNSWSQCosine::IndexHNSWSQCosine( } +// +IndexHNSWPQCosine::IndexHNSWPQCosine() = default; + +IndexHNSWPQCosine::IndexHNSWPQCosine(int d, size_t pq_M, int M, size_t pq_nbits) : + IndexHNSW(new IndexPQCosine(d, pq_M, pq_nbits), M) +{ + own_fields = true; +} + +void IndexHNSWPQCosine::train(idx_t n, const float* x) { + IndexHNSW::train(n, x); + (dynamic_cast(storage))->pq.compute_sdc_table(); +} + + } diff --git a/thirdparty/faiss/faiss/IndexCosine.h b/thirdparty/faiss/faiss/IndexCosine.h index cd283f795..3bb3eae7a 100644 --- a/thirdparty/faiss/faiss/IndexCosine.h +++ b/thirdparty/faiss/faiss/IndexCosine.h @@ -13,8 +13,13 @@ #pragma once +#include +#include + #include #include +#include +#include #include @@ -117,6 +122,23 @@ struct IndexScalarQuantizerCosine : IndexScalarQuantizer, HasInverseL2Norms { const float* get_inverse_l2_norms() const override; }; +// +struct IndexPQCosine : IndexPQ, HasInverseL2Norms { + L2NormsStorage inverse_norms_storage; + + IndexPQCosine(int d, size_t M, size_t nbits); + + IndexPQCosine(); + + void add(idx_t n, const float* x) override; + void reset() override; + + DistanceComputer* get_distance_computer() const override; + + const float* get_inverse_l2_norms() const override; +}; + + // struct IndexHNSWFlatCosine : IndexHNSW { IndexHNSWFlatCosine(); @@ -132,5 +154,17 @@ struct IndexHNSWSQCosine : IndexHNSW { int M); }; +// +struct IndexHNSWPQCosine : IndexHNSW { + IndexHNSWPQCosine(); + IndexHNSWPQCosine( + int d, + size_t pq_M, + int M, + size_t pq_nbits); + + void train(idx_t n, const float* x) override; +}; + } diff --git a/thirdparty/faiss/faiss/impl/index_read.cpp b/thirdparty/faiss/faiss/impl/index_read.cpp index b980f2859..9889342f5 100644 --- a/thirdparty/faiss/faiss/impl/index_read.cpp +++ b/thirdparty/faiss/faiss/impl/index_read.cpp @@ -752,6 +752,23 @@ Index* read_index(IOReader* f, int io_flags) { FAISS_THROW_IF_NOT( idxl->codes.size() == idxl->ntotal * idxl->code_size); idx = idxl; + } else if (h == fourcc("IxP7")) { + IndexPQCosine* idxp = new IndexPQCosine(); + read_index_header(idxp, f); + read_ProductQuantizer(&idxp->pq, f); + idxp->code_size = idxp->pq.code_size; + READVECTOR(idxp->codes); + READ1(idxp->search_type); + READ1(idxp->encode_signs); + READ1(idxp->polysemous_ht); + // read inverse norms + READVECTOR(idxp->inverse_norms_storage.inverse_l2_norms); + + if (!(io_flags & IO_FLAG_PQ_SKIP_SDC_TABLE)) { + idxp->pq.compute_sdc_table (); + } + + idx = idxp; } else if ( h == fourcc("IxPQ") || h == fourcc("IxPo") || h == fourcc("IxPq")) { // IxPQ and IxPo were merged into the same IndexPQ object @@ -1166,7 +1183,7 @@ Index* read_index(IOReader* f, int io_flags) { } else if ( h == fourcc("IHNf") || h == fourcc("IHNp") || h == fourcc("IHNs") || h == 
fourcc("IHN2") || h == fourcc("IHNc") || h == fourcc("IHN9") || - h == fourcc("IHN8")) { + h == fourcc("IHN8") || h == fourcc("IHN7")) { IndexHNSW* idxhnsw = nullptr; if (h == fourcc("IHNf")) idxhnsw = new IndexHNSWFlat(); @@ -1182,6 +1199,8 @@ Index* read_index(IOReader* f, int io_flags) { idxhnsw = new IndexHNSWFlatCosine(); if (h == fourcc("IHN8")) idxhnsw = new IndexHNSWSQCosine(); + if (h == fourcc("IHN7")) + idxhnsw = new IndexHNSWPQCosine(); read_index_header(idxhnsw, f); if (h == fourcc("IHNc")) { READ1(idxhnsw->keep_max_size_level0); diff --git a/thirdparty/faiss/faiss/impl/index_write.cpp b/thirdparty/faiss/faiss/impl/index_write.cpp index b465e4bce..d44a223ea 100644 --- a/thirdparty/faiss/faiss/impl/index_write.cpp +++ b/thirdparty/faiss/faiss/impl/index_write.cpp @@ -591,6 +591,18 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { WRITE1(code_size_i); write_VectorTransform(&idxl->rrot, f); WRITEVECTOR(idxl->codes); + } else if (const IndexPQCosine* idxp = dynamic_cast(idx)) { + uint32_t h = fourcc("IxP7"); + WRITE1(h); + write_index_header(idx, f); + write_ProductQuantizer(&idxp->pq, f); + WRITEVECTOR(idxp->codes); + // search params -- maybe not useful to store? + WRITE1(idxp->search_type); + WRITE1(idxp->encode_signs); + WRITE1(idxp->polysemous_ht); + // inverse norms + WRITEVECTOR(idxp->inverse_norms_storage.inverse_l2_norms); } else if (const IndexPQ* idxp = dynamic_cast(idx)) { uint32_t h = fourcc("IxPq"); WRITE1(h); @@ -969,6 +981,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { : dynamic_cast(idx) ? fourcc("IHNc") : dynamic_cast(idx) ? fourcc("IHN9") : dynamic_cast(idx) ? fourcc("IHN8") + : dynamic_cast(idx) ? fourcc("IHN7") : 0; FAISS_THROW_IF_NOT(h != 0); WRITE1(h); From 8194ad6303800918a6f0a135fa3cea69e8391a35 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Sat, 3 Aug 2024 08:11:46 -0400 Subject: [PATCH 04/21] Add IndexPRQ Signed-off-by: Alexandr Guzhva --- include/knowhere/comp/index_param.h | 2 + include/knowhere/index/index_table.h | 14 ++ src/index/hnsw/faiss_hnsw.cc | 173 +++++++++++++++++++- src/index/hnsw/faiss_hnsw_config.h | 26 +++ thirdparty/faiss/faiss/IndexCosine.cpp | 76 +++++++++ thirdparty/faiss/faiss/IndexCosine.h | 47 ++++++ thirdparty/faiss/faiss/impl/index_read.cpp | 17 +- thirdparty/faiss/faiss/impl/index_write.cpp | 13 ++ 8 files changed, 366 insertions(+), 2 deletions(-) diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index c89a0ab5b..d9b98e6df 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -56,6 +56,7 @@ constexpr const char* INDEX_DISKANN = "DISKANN"; constexpr const char* INDEX_FAISS_HNSW_FLAT = "FAISS_HNSW_FLAT"; constexpr const char* INDEX_FAISS_HNSW_SQ = "FAISS_HNSW_SQ"; constexpr const char* INDEX_FAISS_HNSW_PQ = "FAISS_HNSW_PQ"; +constexpr const char* INDEX_FAISS_HNSW_PRQ = "FAISS_HNSW_PRQ"; constexpr const char* INDEX_SPARSE_INVERTED_INDEX = "SPARSE_INVERTED_INDEX"; constexpr const char* INDEX_SPARSE_WAND = "SPARSE_WAND"; @@ -156,6 +157,7 @@ constexpr const char* OVERVIEW_LEVELS = "overview_levels"; // FAISS additional Params constexpr const char* SQ_TYPE = "sq_type"; // for IVF_SQ and HNSW_SQ +constexpr const char* PRQ_NUM = "nrq"; // for PRQ, number of redisual quantizers // Sparse Params constexpr const char* DROP_RATIO_BUILD = "drop_ratio_build"; diff --git a/include/knowhere/index/index_table.h b/include/knowhere/index/index_table.h index 464804929..cda09e960 100644 --- a/include/knowhere/index/index_table.h 
+++ b/include/knowhere/index/index_table.h @@ -70,6 +70,10 @@ static std::set> legal_knowhere_index = { {IndexEnum::INDEX_FAISS_HNSW_FLAT, VecType::VECTOR_FLOAT}, {IndexEnum::INDEX_FAISS_HNSW_SQ, VecType::VECTOR_FLOAT}, + + {IndexEnum::INDEX_FAISS_HNSW_PQ, VecType::VECTOR_FLOAT}, + + {IndexEnum::INDEX_FAISS_HNSW_PRQ, VecType::VECTOR_FLOAT}, // diskann {IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT}, {IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT16}, @@ -120,6 +124,16 @@ static std::set legal_support_mmap_knowhere_index = { IndexEnum::INDEX_HNSW_SQ8_REFINE, IndexEnum::INDEX_HNSW_SQ8_REFINE, IndexEnum::INDEX_HNSW_SQ8_REFINE, + + // faiss hnsw + IndexEnum::INDEX_FAISS_HNSW_FLAT, + + IndexEnum::INDEX_FAISS_HNSW_SQ, + + IndexEnum::INDEX_FAISS_HNSW_PQ, + + IndexEnum::INDEX_FAISS_HNSW_PRQ, + // sparse index IndexEnum::INDEX_SPARSE_INVERTED_INDEX, IndexEnum::INDEX_SPARSE_WAND, diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 40fdc2f91..7f65e31fc 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -988,10 +988,179 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { return Status::success; } - }; +// this index trains PRQ and HNSW+FLAT separately, then constructs HNSW+PRQ +template +class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { + public: + BaseFaissRegularIndexHNSWPRQNode(const int32_t& version, const Object& object) : + BaseFaissRegularIndexHNSWNode(version, object) {} + + std::unique_ptr + CreateConfig() const override { + return std::make_unique(); + } + + std::string + Type() const override { + return knowhere::IndexEnum::INDEX_FAISS_HNSW_PRQ; + } + + protected: + std::unique_ptr tmp_index_prq; + + Status TrainInternal(const DataSetPtr dataset, const Config& cfg) override { + // number of rows + auto rows = dataset->GetRows(); + // dimensionality of the data + auto dim = dataset->GetDim(); + // data + auto data = dataset->GetTensor(); + + // config + auto hnsw_cfg = static_cast(cfg); + + auto metric = Str2FaissMetricType(hnsw_cfg.metric_type.value()); + if (!metric.has_value()) { + LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << hnsw_cfg.metric_type.value(); + return Status::invalid_metric_type; + } + + // create an index + const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE); + + // HNSW + PQ index yields BAD recall somewhy. + // Let's build HNSW+FLAT index, then replace FLAT with PQ + + std::unique_ptr hnsw_index; + if (is_cosine) { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); + } else { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); + } + + hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); + + // pq + std::unique_ptr prq_index; + if (is_cosine) { + prq_index = std::make_unique( + dim, hnsw_cfg.nrq.value(), hnsw_cfg.m.value(), hnsw_cfg.nbits.value()); + } else { + prq_index = std::make_unique( + dim, hnsw_cfg.nrq.value(), hnsw_cfg.m.value(), hnsw_cfg.nbits.value(), metric.value()); + } + + + // should refine be used? 
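+        // A refine index (IndexRefineFlat, or IndexRefine over an SQ index, depending
+        // on refine_type) keeps a second, higher-precision copy of the vectors and
+        // re-ranks the candidates produced by the HNSW search.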
+ std::unique_ptr final_index; + if (hnsw_cfg.refine.value_or(false)) { + // yes + auto final_index_cnd = pick_refine_index(hnsw_cfg.refine_type, std::move(hnsw_index)); + if (!final_index_cnd.has_value()) { + return Status::invalid_args; + } + + // assign + final_index = std::move(final_index_cnd.value()); + } else { + // no refine + + // assign + final_index = std::move(hnsw_index); + } + + // train hnswflat + LOG_KNOWHERE_INFO_ << "Training HNSW Index"; + + final_index->train(rows, (const float*)data); + + // train prq + LOG_KNOWHERE_INFO_ << "Training ProductResidualQuantizer Index"; + + prq_index->train(rows, (const float*)data); + + // done + index = std::move(final_index); + tmp_index_prq = std::move(prq_index); + return Status::success; + } + + Status + AddInternal(const DataSetPtr dataset, const Config&) override { + if (this->index == nullptr) { + LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; + return Status::empty_index; + } + + auto data = dataset->GetTensor(); + auto rows = dataset->GetRows(); + try { + // hnsw + LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; + + index->add(rows, reinterpret_cast(data)); + + // pq + LOG_KNOWHERE_INFO_ << "Adding " << rows << " to ProductResidualQuantizer Index"; + + tmp_index_prq->add(rows, reinterpret_cast(data)); + + // we're done. + // throw away flat and replace it with prq + + // check if we have a refine available. + faiss::IndexHNSW* index_hnsw = nullptr; + + faiss::IndexRefine* const index_refine = + dynamic_cast(index.get()); + + if (index_refine != nullptr) { + index_hnsw = dynamic_cast(index_refine->base_index); + } else { + index_hnsw = dynamic_cast(index.get()); + } + + // recreate hnswprq + std::unique_ptr index_hnsw_prq; + + if (index_hnsw->storage->is_cosine) { + index_hnsw_prq = std::make_unique(); + } else { + index_hnsw_prq = std::make_unique(); + } + + // C++ slicing + static_cast(*index_hnsw_prq) = + std::move(static_cast(*index_hnsw)); + + // clear out the storage + delete index_hnsw->storage; + index_hnsw->storage = nullptr; + index_hnsw_prq->storage = nullptr; + + // replace storage + index_hnsw_prq->storage = tmp_index_prq.release(); + + // replace if refine + if (index_refine != nullptr) { + delete index_refine->base_index; + index_refine->base_index = index_hnsw_prq.release(); + } else { + index = std::move(index_hnsw_prq); + } + + } catch (const std::exception& e) { + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); + return Status::faiss_inner_error; + } + + return Status::success; + } +}; + KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNode, fp32); @@ -999,4 +1168,6 @@ KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_SQ, BaseFaissRegularIndexHNSWSQNode, KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PQ, BaseFaissRegularIndexHNSWPQNode, fp32); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNode, fp32); + } // namespace knowhere diff --git a/src/index/hnsw/faiss_hnsw_config.h b/src/index/hnsw/faiss_hnsw_config.h index c1fcb76f6..77116ce9a 100644 --- a/src/index/hnsw/faiss_hnsw_config.h +++ b/src/index/hnsw/faiss_hnsw_config.h @@ -241,6 +241,32 @@ class FaissHnswPqConfig : public FaissHnswConfig { } }; +class FaissHnswPrqConfig : public FaissHnswConfig { +public: + // number of subquantizer splits + CFG_INT m; + // number of residual quantizers + CFG_INT nrq; + // number of bits per subquantizer + CFG_INT nbits; + KNOHWERE_DECLARE_CONFIG(FaissHnswPrqConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(m) + .description("Number of 
splits") + .set_default(2) + .for_train() + .set_range(1, 65536); + KNOWHERE_CONFIG_DECLARE_FIELD(nrq) + .description("Number of residual subquantizers") + .for_train() + .set_range(1, 64); + KNOWHERE_CONFIG_DECLARE_FIELD(nbits) + .description("nbits") + .set_default(8) + .for_train() + .set_range(1, 64); + } +}; + } // namespace knowhere #endif /* FAISS_HNSW_CONFIG_H */ diff --git a/thirdparty/faiss/faiss/IndexCosine.cpp b/thirdparty/faiss/faiss/IndexCosine.cpp index fc12c61bb..8f8c2a814 100644 --- a/thirdparty/faiss/faiss/IndexCosine.cpp +++ b/thirdparty/faiss/faiss/IndexCosine.cpp @@ -391,6 +391,57 @@ DistanceComputer* IndexPQCosine::get_distance_computer() const { } +////////////////////////////////////////////////////////////////////////////////// + +IndexProductResidualQuantizerCosine::IndexProductResidualQuantizerCosine( + int d, + size_t nsplits, + size_t Msub, + size_t nbits, + AdditiveQuantizer::Search_type_t search_type) : + IndexProductResidualQuantizer(d, nsplits, Msub, nbits, MetricType::METRIC_INNER_PRODUCT, search_type) { + is_cosine = true; +} + + +IndexProductResidualQuantizerCosine::IndexProductResidualQuantizerCosine() : + IndexProductResidualQuantizer() { + metric_type = MetricType::METRIC_INNER_PRODUCT; + is_cosine = true; +} + +void IndexProductResidualQuantizerCosine::add(idx_t n, const float* x) { + FAISS_THROW_IF_NOT(is_trained); + if (n == 0) { + return; + } + + // todo aguzhva: + // it is a tricky situation at this moment, because IndexProductResidualQuantizerCosine + // contains duplicate norms (one in IndexFlatCodes and one in HasInverseL2Norms). + // Norms in IndexFlatCodes are going to be removed in the future. + IndexProductResidualQuantizer::add(n, x); + inverse_norms_storage.add(x, n, d); +} + +void IndexProductResidualQuantizerCosine::reset() { + IndexProductResidualQuantizer::reset(); + inverse_norms_storage.reset(); +} + +const float* IndexProductResidualQuantizerCosine::get_inverse_l2_norms() const { + return inverse_norms_storage.inverse_l2_norms.data(); +} + +DistanceComputer* IndexProductResidualQuantizerCosine::get_distance_computer() const { + return new WithCosineNormDistanceComputer( + this->get_inverse_l2_norms(), + this->d, + std::unique_ptr(IndexProductResidualQuantizer::get_FlatCodesDistanceComputer()) + ); +} + + ////////////////////////////////////////////////////////////////////////////////// // @@ -436,6 +487,31 @@ void IndexHNSWPQCosine::train(idx_t n, const float* x) { (dynamic_cast(storage))->pq.compute_sdc_table(); } +// +IndexHNSWProductResidualQuantizer::IndexHNSWProductResidualQuantizer() = default; + +IndexHNSWProductResidualQuantizer::IndexHNSWProductResidualQuantizer( + int d, + size_t prq_nsplits, + size_t prq_Msub, + size_t prq_nbits, + size_t M, + MetricType metric, + AdditiveQuantizer::Search_type_t prq_search_type +) : IndexHNSW(new IndexProductResidualQuantizer(d, prq_nsplits, prq_Msub, prq_nbits, metric, prq_search_type), M) {} + +// +IndexHNSWProductResidualQuantizerCosine::IndexHNSWProductResidualQuantizerCosine() = default; + +IndexHNSWProductResidualQuantizerCosine::IndexHNSWProductResidualQuantizerCosine( + int d, + size_t prq_nsplits, + size_t prq_Msub, + size_t prq_nbits, + size_t M, + AdditiveQuantizer::Search_type_t prq_search_type +) : IndexHNSW(new IndexHNSWProductResidualQuantizerCosine(d, prq_nsplits, prq_Msub, prq_nbits, prq_search_type), M) {} + } diff --git a/thirdparty/faiss/faiss/IndexCosine.h b/thirdparty/faiss/faiss/IndexCosine.h index 3bb3eae7a..6fe392b36 100644 --- 
a/thirdparty/faiss/faiss/IndexCosine.h +++ b/thirdparty/faiss/faiss/IndexCosine.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -138,6 +139,26 @@ struct IndexPQCosine : IndexPQ, HasInverseL2Norms { const float* get_inverse_l2_norms() const override; }; +// +struct IndexProductResidualQuantizerCosine : IndexProductResidualQuantizer, HasInverseL2Norms { + L2NormsStorage inverse_norms_storage; + + IndexProductResidualQuantizerCosine( + int d, ///< dimensionality of the input vectors + size_t nsplits, ///< number of residual quantizers + size_t Msub, ///< number of subquantizers per RQ + size_t nbits, ///< number of bit per subvector index + AdditiveQuantizer::Search_type_t search_type = AdditiveQuantizer::ST_decompress); + + IndexProductResidualQuantizerCosine(); + + void add(idx_t n, const float* x) override; + void reset() override; + + DistanceComputer* get_distance_computer() const override; + + const float* get_inverse_l2_norms() const override; +}; // struct IndexHNSWFlatCosine : IndexHNSW { @@ -166,5 +187,31 @@ struct IndexHNSWPQCosine : IndexHNSW { void train(idx_t n, const float* x) override; }; +// +struct IndexHNSWProductResidualQuantizer : IndexHNSW { + IndexHNSWProductResidualQuantizer(); + IndexHNSWProductResidualQuantizer( + int d, ///< dimensionality of the input vectors + size_t prq_nsplits, ///< number of residual quantizers + size_t prq_Msub, ///< number of subquantizers per RQ + size_t prq_nbits, ///< number of bit per subvector index + size_t M, /// HNSW Param + MetricType metric = METRIC_L2, + AdditiveQuantizer::Search_type_t prq_search_type = AdditiveQuantizer::ST_decompress + ); +}; + +struct IndexHNSWProductResidualQuantizerCosine : IndexHNSW { + IndexHNSWProductResidualQuantizerCosine(); + IndexHNSWProductResidualQuantizerCosine( + int d, ///< dimensionality of the input vectors + size_t prq_nsplits, ///< number of residual quantizers + size_t prq_Msub, ///< number of subquantizers per RQ + size_t prq_nbits, ///< number of bit per subvector index + size_t M, /// HNSW Param + AdditiveQuantizer::Search_type_t prq_search_type = AdditiveQuantizer::ST_decompress + ); +}; + } diff --git a/thirdparty/faiss/faiss/impl/index_read.cpp b/thirdparty/faiss/faiss/impl/index_read.cpp index 9889342f5..599953208 100644 --- a/thirdparty/faiss/faiss/impl/index_read.cpp +++ b/thirdparty/faiss/faiss/impl/index_read.cpp @@ -813,6 +813,16 @@ Index* read_index(IOReader* f, int io_flags) { READ1(idxr->code_size); READVECTOR(idxr->codes); idx = idxr; + } else if (h == fourcc("IxP5")) { + auto idxpr = new IndexProductResidualQuantizerCosine(); + read_index_header(idxpr, f); + read_ProductResidualQuantizer(&idxpr->prq, f, io_flags); + READ1(idxpr->code_size); + READVECTOR(idxpr->codes); + // read inverse norms + READVECTOR(idxpr->inverse_norms_storage.inverse_l2_norms); + + idx = idxpr; } else if (h == fourcc("IxPR")) { auto idxpr = new IndexProductResidualQuantizer(); read_index_header(idxpr, f); @@ -1183,7 +1193,8 @@ Index* read_index(IOReader* f, int io_flags) { } else if ( h == fourcc("IHNf") || h == fourcc("IHNp") || h == fourcc("IHNs") || h == fourcc("IHN2") || h == fourcc("IHNc") || h == fourcc("IHN9") || - h == fourcc("IHN8") || h == fourcc("IHN7")) { + h == fourcc("IHN8") || h == fourcc("IHN7") || h == fourcc("IHN6") || + h == fourcc("IHN5")) { IndexHNSW* idxhnsw = nullptr; if (h == fourcc("IHNf")) idxhnsw = new IndexHNSWFlat(); @@ -1201,6 +1212,10 @@ Index* read_index(IOReader* f, int io_flags) { idxhnsw = new IndexHNSWSQCosine(); if (h == fourcc("IHN7")) 
idxhnsw = new IndexHNSWPQCosine(); + if (h == fourcc("IHN6")) + idxhnsw = new IndexHNSWProductResidualQuantizer(); + if (h == fourcc("IHN5")) + idxhnsw = new IndexHNSWProductResidualQuantizerCosine(); read_index_header(idxhnsw, f); if (h == fourcc("IHNc")) { READ1(idxhnsw->keep_max_size_level0); diff --git a/thirdparty/faiss/faiss/impl/index_write.cpp b/thirdparty/faiss/faiss/impl/index_write.cpp index d44a223ea..6a73729ed 100644 --- a/thirdparty/faiss/faiss/impl/index_write.cpp +++ b/thirdparty/faiss/faiss/impl/index_write.cpp @@ -631,6 +631,17 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { write_LocalSearchQuantizer(&idxr_2->lsq, f); WRITE1(idxr_2->code_size); WRITEVECTOR(idxr_2->codes); + } else if ( + const IndexProductResidualQuantizerCosine* idxpr = + dynamic_cast(idx)) { + uint32_t h = fourcc("IxP5"); + WRITE1(h); + write_index_header(idx, f); + write_ProductResidualQuantizer(&idxpr->prq, f); + WRITE1(idxpr->code_size); + WRITEVECTOR(idxpr->codes); + // inverse norms + WRITEVECTOR(idxpr->inverse_norms_storage.inverse_l2_norms); } else if ( const IndexProductResidualQuantizer* idxpr = dynamic_cast(idx)) { @@ -982,6 +993,8 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { : dynamic_cast(idx) ? fourcc("IHN9") : dynamic_cast(idx) ? fourcc("IHN8") : dynamic_cast(idx) ? fourcc("IHN7") + : dynamic_cast(idx) ? fourcc("IHN6") + : dynamic_cast(idx) ? fourcc("IHN5") : 0; FAISS_THROW_IF_NOT(h != 0); WRITE1(h); From e866b23bb3ec7c20faf5dd73312eac925c439ca1 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Mon, 5 Aug 2024 11:33:23 -0400 Subject: [PATCH 05/21] fix PRQ params Signed-off-by: Alexandr Guzhva --- src/index/hnsw/faiss_hnsw.cc | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 7f65e31fc..3e48e6411 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -1031,8 +1031,8 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { // create an index const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE); - // HNSW + PQ index yields BAD recall somewhy. - // Let's build HNSW+FLAT index, then replace FLAT with PQ + // HNSW + PRQ index yields BAD recall somewhy. + // Let's build HNSW+FLAT index, then replace FLAT with PRQ std::unique_ptr hnsw_index; if (is_cosine) { @@ -1043,17 +1043,21 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); - // pq + // prq + faiss::AdditiveQuantizer::Search_type_t prq_search_type = + (metric.value() == faiss::MetricType::METRIC_INNER_PRODUCT) ? + faiss::AdditiveQuantizer::Search_type_t::ST_LUT_nonorm : + faiss::AdditiveQuantizer::Search_type_t::ST_norm_float; + std::unique_ptr prq_index; if (is_cosine) { prq_index = std::make_unique( - dim, hnsw_cfg.nrq.value(), hnsw_cfg.m.value(), hnsw_cfg.nbits.value()); + dim, hnsw_cfg.m.value(), hnsw_cfg.nrq.value(), hnsw_cfg.nbits.value(), prq_search_type); } else { prq_index = std::make_unique( - dim, hnsw_cfg.nrq.value(), hnsw_cfg.m.value(), hnsw_cfg.nbits.value(), metric.value()); + dim, hnsw_cfg.m.value(), hnsw_cfg.nrq.value(), hnsw_cfg.nbits.value(), metric.value(), prq_search_type); } - // should refine be used? 
std::unique_ptr final_index; if (hnsw_cfg.refine.value_or(false)) { From 7138465407e05adaf81337a7c4cfc5d76ca99c2a Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Mon, 5 Aug 2024 17:57:39 -0400 Subject: [PATCH 06/21] benchmarks Signed-off-by: Alexandr Guzhva --- benchmark/CMakeLists.txt | 1 + benchmark/benchmark_base.h | 7 +- benchmark/hdf5/benchmark_faiss_hnsw.cpp | 509 ++++++++++++++++++++++++ benchmark/hdf5/benchmark_knowhere.h | 119 ++++-- benchmark/utils.h | 3 + src/index/hnsw/faiss_hnsw_config.h | 2 +- 6 files changed, 601 insertions(+), 40 deletions(-) create mode 100644 benchmark/hdf5/benchmark_faiss_hnsw.cpp diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index 3f6ae5ed2..687ab33fd 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -51,6 +51,7 @@ benchmark_test(benchmark_float_bitset hdf5/benchmark_float_bitset.cpp) benchmark_test(benchmark_float_qps hdf5/benchmark_float_qps.cpp) benchmark_test(benchmark_float_range hdf5/benchmark_float_range.cpp) benchmark_test(benchmark_float_range_bitset hdf5/benchmark_float_range_bitset.cpp) +benchmark_test(benchmark_faiss_hnsw hdf5/benchmark_faiss_hnsw.cpp) benchmark_test(gen_hdf5_file hdf5/gen_hdf5_file.cpp) benchmark_test(gen_fbin_file hdf5/gen_fbin_file.cpp) diff --git a/benchmark/benchmark_base.h b/benchmark/benchmark_base.h index 81eb4c315..6acd79227 100644 --- a/benchmark/benchmark_base.h +++ b/benchmark/benchmark_base.h @@ -14,7 +14,10 @@ #include #include +#include #include +#include +#include #define CALC_TIME_SPAN(X) \ double t_start = elapsed(); \ @@ -39,7 +42,7 @@ class Benchmark_base { } } - inline double + static inline double elapsed() { struct timeval tv; gettimeofday(&tv, nullptr); @@ -86,7 +89,7 @@ class Benchmark_base { return (hit * 1.0f / (nq * min_k)); } - float + static float CalcRecall(const int64_t* g_ids, const int64_t* ids, int32_t nq, int32_t k) { int32_t hit = 0; for (int32_t i = 0; i < nq; i++) { diff --git a/benchmark/hdf5/benchmark_faiss_hnsw.cpp b/benchmark/hdf5/benchmark_faiss_hnsw.cpp new file mode 100644 index 000000000..1812b96cc --- /dev/null +++ b/benchmark/hdf5/benchmark_faiss_hnsw.cpp @@ -0,0 +1,509 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
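+
+// Every test below builds an index on synthetic random data, serializes it to disk,
+// reloads it, and requires both the freshly built and the reloaded index to reach
+// at least 0.9 recall against a golden brute-force (FAISS_IDMAP) baseline.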
+ +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "benchmark_knowhere.h" +#include "knowhere/comp/brute_force.h" +#include "knowhere/comp/index_param.h" +#include "knowhere/comp/knowhere_config.h" +#include "knowhere/dataset.h" + + +knowhere::DataSetPtr +GenDataSet(int rows, int dim, const uint64_t seed = 42) { + std::mt19937 rng(seed); + std::uniform_real_distribution<> distrib(-1.0, 1.0); + float* ts = new float[rows * dim]; + for (int i = 0; i < rows * dim; ++i) { + ts[i] = (float)distrib(rng); + } + auto ds = knowhere::GenDataSet(rows, dim, ts); + ds->SetIsOwner(true); + return ds; +} + +// unlike other benchmarks, this one operates on a synthetic data +// and verifies the correctness of many-many variants of FAISS HNSW indices. +class Benchmark_Faiss_Hnsw : public Benchmark_knowhere, public ::testing::Test { +public: + template + void test_hnsw( + const knowhere::DataSetPtr& default_ds_ptr, + const knowhere::DataSetPtr& query_ds_ptr, + const knowhere::DataSetPtr& golden_result, + const std::vector& index_params, + const knowhere::Json& conf + ) { + const std::string index_type = conf[knowhere::meta::INDEX_TYPE].get(); + + // load indices + std::string index_file_name = get_index_name( + ann_test_name_, index_type, index_params); + + // our index + // first, we create an index and save it + auto index = create_index( + index_type, + index_file_name, + default_ds_ptr, + conf + ); + + // then, we force it to be loaded in order to test load & save + auto index_loaded = create_index( + index_type, + index_file_name, + default_ds_ptr, + conf + ); + + auto result = index.Search(query_ds_ptr, conf, nullptr); + auto result_loaded = index_loaded.Search(query_ds_ptr, conf, nullptr); + + // calc recall + auto recall = this->CalcRecall( + golden_result->GetIds(), + result.value()->GetIds(), + query_ds_ptr->GetRows(), + conf[knowhere::meta::TOPK].get() + ); + + auto recall_loaded = this->CalcRecall( + golden_result->GetIds(), + result_loaded.value()->GetIds(), + query_ds_ptr->GetRows(), + conf[knowhere::meta::TOPK].get() + ); + + printf("Recall is %f, %f\n", recall, recall_loaded); + + ASSERT_GE(recall, 0.9); + ASSERT_GE(recall_loaded, 0.9); + } + +protected: + void + SetUp() override { + T0_ = elapsed(); + set_ann_test_name("faiss_hnsw"); + + knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2); + + cfg_[knowhere::indexparam::HNSW_M] = 16; + cfg_[knowhere::indexparam::EFCONSTRUCTION] = 96; + cfg_[knowhere::indexparam::EF] = 64; + cfg_[knowhere::meta::TOPK] = TOPK; + + // create baseline indices here + CreateGoldenIndices(); + } + + void CreateGoldenIndices() { + const std::string golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP; + + uint64_t rng_seed = 1; + for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) { + for (const int32_t dim : DIMS) { + for (const int32_t nb : NBS) { + knowhere::Json conf = cfg_; + conf[knowhere::meta::METRIC_TYPE] = DISTANCE_TYPES[distance_type]; + conf[knowhere::meta::DIM] = dim; + + std::vector params = {(int)distance_type, dim, nb}; + + auto default_ds_ptr = GenDataSet(nb, dim, rng_seed); + rng_seed += 1; + + // create a golden index + std::string golden_index_file_name = get_index_name( + ann_test_name_, golden_index_type, params); + + create_index( + golden_index_type, + golden_index_file_name, + default_ds_ptr, + conf, + "golden " + ); + } + } + } + } + + const std::vector DISTANCE_TYPES = {"L2", "IP", "COSINE"}; + const std::vector DIMS = {13, 16, 
27}; + const std::vector NBS = {16384, 9632 + 16384}; + const int32_t NQ = 256; + const int32_t TOPK = 64; + + const std::vector SQ_TYPES = {"SQ6", "SQ8", "BF16", "FP16"}; + + // todo: enable 10 and 12 bits when the PQ training code is provided + // const std::vector NBITS = {8, 10, 12}; + const std::vector NBITS = {8}; + + std::unordered_map> SQ_ALLOWED_REFINES = { + { "SQ6", {"SQ8", "BF16", "FP16", "FLAT"} }, + { "SQ8", {"BF16", "FP16", "FLAT"} }, + { "BF16", {"FLAT"} }, + { "FP16", {"FLAT"} } + }; + + std::vector PQ_ALLOWED_REFINES = { + {"SQ6", "SQ8", "BF16", "FP16", "FLAT"} + }; +}; + +// +TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWFLAT) { + const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_FLAT; + const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP; + + uint64_t rng_seed = 1; + for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) { + for (const int32_t dim : DIMS) { + auto query_ds_ptr = GenDataSet(NQ, dim, rng_seed + 1234567); + + for (const int32_t nb : NBS) { + knowhere::Json conf = cfg_; + conf[knowhere::meta::METRIC_TYPE] = DISTANCE_TYPES[distance_type]; + conf[knowhere::meta::DIM] = dim; + conf[knowhere::meta::ROWS] = nb; + conf[knowhere::meta::INDEX_TYPE] = index_type; + + std::vector params = {(int)distance_type, dim, nb}; + + // generate a default dataset + auto default_ds_ptr = GenDataSet(nb, dim, rng_seed); + rng_seed += 1; + + // get a golden result + std::string golden_index_file_name = get_index_name( + ann_test_name_, golden_index_type, params); + + auto golden_index = create_index( + golden_index_type, + golden_index_file_name, + default_ds_ptr, + conf, + "golden " + ); + + auto golden_result = golden_index.Search(query_ds_ptr, conf, nullptr); + + // + printf("\n"); + printf("Processing HNSW,Flat fp32 for %s distance, dim=%d, nrows=%d\n", + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + // test a candidate + test_hnsw( + default_ds_ptr, + query_ds_ptr, + golden_result.value(), + params, + conf); + } + } + } +} + +// +TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWSQ) { + const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_SQ; + const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP; + + uint64_t rng_seed = 1; + for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) { + for (const int32_t dim : DIMS) { + auto query_ds_ptr = GenDataSet(NQ, dim, rng_seed + 1234567); + + for (const int32_t nb : NBS) { + for (size_t sq_type = 0; sq_type < SQ_TYPES.size(); sq_type++) { + knowhere::Json conf = cfg_; + conf[knowhere::meta::METRIC_TYPE] = DISTANCE_TYPES[distance_type]; + conf[knowhere::meta::DIM] = dim; + conf[knowhere::meta::ROWS] = nb; + conf[knowhere::meta::INDEX_TYPE] = index_type; + conf[knowhere::indexparam::SQ_TYPE] = SQ_TYPES[sq_type]; + + std::vector params = {(int)distance_type, dim, nb, (int)sq_type}; + + // generate a default dataset + auto default_ds_ptr = GenDataSet(nb, dim, rng_seed); + rng_seed += 1; + + // get a golden result + std::string golden_index_file_name = get_index_name( + ann_test_name_, golden_index_type, params); + + auto golden_index = create_index( + golden_index_type, + golden_index_file_name, + default_ds_ptr, + conf, + "golden " + ); + + auto golden_result = golden_index.Search(query_ds_ptr, conf, nullptr); + + // + printf("\n"); + printf("Processing HNSW,SQ(%s) fp32 for %s distance, dim=%d, nrows=%d\n", + SQ_TYPES[sq_type].c_str(), + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + // 
test a candidate + test_hnsw( + default_ds_ptr, + query_ds_ptr, + golden_result.value(), + params, + conf); + + // test refines + const auto& allowed_refs = SQ_ALLOWED_REFINES[SQ_TYPES[sq_type]]; + for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + conf_refine["refine_type"] = allowed_refs[allowed_ref_idx]; + + std::vector params_refine = + {(int)distance_type, dim, nb, (int)sq_type, (int)allowed_ref_idx}; + + // + printf("\n"); + printf("Processing HNSW,SQ(%s) with %s refine, fp32 for %s distance, dim=%d, nrows=%d\n", + SQ_TYPES[sq_type].c_str(), + allowed_refs[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + // test a candidate + test_hnsw( + default_ds_ptr, + query_ds_ptr, + golden_result.value(), + params_refine, + conf_refine); + } + } + } + } + } +} + + +// +TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWPQ) { + const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_PQ; + const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP; + + uint64_t rng_seed = 1; + for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) { + for (const int32_t dim : {16}) { + auto query_ds_ptr = GenDataSet(NQ, dim, rng_seed + 1234567); + + for (const int32_t nb : NBS) { + for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { + const int pq_m = 4; + + knowhere::Json conf = cfg_; + conf[knowhere::meta::METRIC_TYPE] = DISTANCE_TYPES[distance_type]; + conf[knowhere::meta::DIM] = dim; + conf[knowhere::meta::ROWS] = nb; + conf[knowhere::meta::INDEX_TYPE] = index_type; + conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; + conf[knowhere::indexparam::M] = pq_m; + + std::vector params = {(int)distance_type, dim, nb, pq_m, (int)nbits_type}; + + // generate a default dataset + auto default_ds_ptr = GenDataSet(nb, dim, rng_seed); + rng_seed += 1; + + // get a golden result + std::string golden_index_file_name = get_index_name( + ann_test_name_, golden_index_type, params); + + auto golden_index = create_index( + golden_index_type, + golden_index_file_name, + default_ds_ptr, + conf, + "golden " + ); + + auto golden_result = golden_index.Search(query_ds_ptr, conf, nullptr); + + // + printf("\n"); + printf("Processing HNSW,PQ%dx%d fp32 for %s distance, dim=%d, nrows=%d\n", + pq_m, + NBITS[nbits_type], + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + // test a candidate + test_hnsw( + default_ds_ptr, + query_ds_ptr, + golden_result.value(), + params, + conf); + + // test refines + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES.size(); allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + conf_refine["refine_type"] = PQ_ALLOWED_REFINES[allowed_ref_idx]; + + std::vector params_refine = + {(int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + + // + printf("\n"); + printf("Processing HNSW,PQ%dx%d with %s refine, fp32 for %s distance, dim=%d, nrows=%d\n", + pq_m, + NBITS[nbits_type], + PQ_ALLOWED_REFINES[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + // test a candidate + test_hnsw( + default_ds_ptr, + query_ds_ptr, + golden_result.value(), + params_refine, + conf_refine); + } + } + } + } + } +} + + +// +TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWPRQ) { + const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_PRQ; + const 
std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP; + + uint64_t rng_seed = 1; + for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) { + for (const int32_t dim : {16}) { + auto query_ds_ptr = GenDataSet(NQ, dim, rng_seed + 1234567); + + for (const int32_t nb : NBS) { + for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { + const int prq_m = 4; + const int prq_num = 2; + + knowhere::Json conf = cfg_; + conf[knowhere::meta::METRIC_TYPE] = DISTANCE_TYPES[distance_type]; + conf[knowhere::meta::DIM] = dim; + conf[knowhere::meta::ROWS] = nb; + conf[knowhere::meta::INDEX_TYPE] = index_type; + conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; + conf[knowhere::indexparam::M] = prq_m; + conf[knowhere::indexparam::PRQ_NUM] = prq_num; + + std::vector params = {(int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type}; + + // generate a default dataset + auto default_ds_ptr = GenDataSet(nb, dim, rng_seed); + rng_seed += 1; + + // get a golden result + std::string golden_index_file_name = get_index_name( + ann_test_name_, golden_index_type, params); + + auto golden_index = create_index( + golden_index_type, + golden_index_file_name, + default_ds_ptr, + conf, + "golden " + ); + + auto golden_result = golden_index.Search(query_ds_ptr, conf, nullptr); + + // + printf("\n"); + printf("Processing HNSW,PRQ%dx%dx%d fp32 for %s distance, dim=%d, nrows=%d\n", + prq_num, + prq_m, + NBITS[nbits_type], + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + // test a candidate + test_hnsw( + default_ds_ptr, + query_ds_ptr, + golden_result.value(), + params, + conf); + + // test refines + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES.size(); allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + conf_refine["refine_type"] = PQ_ALLOWED_REFINES[allowed_ref_idx]; + + std::vector params_refine = + {(int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, (int)allowed_ref_idx}; + + // + printf("\n"); + printf("Processing HNSW,PRQ%dx%dx%d with %s refine, fp32 for %s distance, dim=%d, nrows=%d\n", + prq_num, + prq_m, + NBITS[nbits_type], + PQ_ALLOWED_REFINES[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + // test a candidate + test_hnsw( + default_ds_ptr, + query_ds_ptr, + golden_result.value(), + params_refine, + conf_refine); + } + } + } + } + } +} \ No newline at end of file diff --git a/benchmark/hdf5/benchmark_knowhere.h b/benchmark/hdf5/benchmark_knowhere.h index 812a8190a..6cff6f5e7 100644 --- a/benchmark/hdf5/benchmark_knowhere.h +++ b/benchmark/hdf5/benchmark_knowhere.h @@ -11,7 +11,12 @@ #pragma once +#include +#include +#include #include +#include +#include #include #include "benchmark/utils.h" @@ -32,7 +37,7 @@ std::string kIPIndexPrefix = kIPIndexDir + "/ip"; class Benchmark_knowhere : public Benchmark_hdf5 { public: - void + static void write_index(knowhere::Index& index, const std::string& filename, const knowhere::Json& conf) { FileIOWriter writer(filename); @@ -53,7 +58,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 { } } - void + static void read_index(knowhere::Index& index, const std::string& filename, const knowhere::Json& conf) { FileIOReader reader(filename); int64_t file_size = reader.size(); @@ -85,67 +90,107 @@ class Benchmark_knowhere : public Benchmark_hdf5 { index.Deserialize(binary_set, conf); } - template - std::string - get_index_name(const std::vector& params) { + template + 
static std::string + get_index_name( + const std::string& ann_test_name, + const std::string& index_type, + const std::vector& params + ) { std::string params_str = ""; for (size_t i = 0; i < params.size(); i++) { params_str += "_" + std::to_string(params[i]); } if constexpr (std::is_same_v) { - return ann_test_name_ + "_" + index_type_ + params_str + "_fp32" + ".index"; + return ann_test_name + "_" + index_type + params_str + "_fp32" + ".index"; } else if constexpr (std::is_same_v) { - return ann_test_name_ + "_" + index_type_ + params_str + "_fp16" + ".index"; + return ann_test_name + "_" + index_type + params_str + "_fp16" + ".index"; } else if constexpr (std::is_same_v) { - return ann_test_name_ + "_" + index_type_ + params_str + "_bf16" + ".index"; + return ann_test_name + "_" + index_type + params_str + "_bf16" + ".index"; } else { - return ann_test_name_ + "_" + index_type_ + params_str + ".index"; + return ann_test_name + "_" + index_type + params_str + ".index"; } } template + std::string + get_index_name(const std::vector& params) { + return this->get_index_name(ann_test_name_, index_type_, params); + } + + template knowhere::Index - create_index(const std::string& index_file_name, const knowhere::Json& conf) { + create_index( + const std::string& index_type, + const std::string& index_file_name, + const knowhere::DataSetPtr& default_ds_ptr, + const knowhere::Json& conf, + const std::optional& additional_name = std::nullopt + ) { + std::string additional_name_s = additional_name.value_or(""); + + printf("[%.3f s] Creating %sindex \"%s\"\n", + get_time_diff(), + additional_name_s.c_str(), + index_type.c_str()); + auto version = knowhere::Version::GetCurrentVersion().VersionNumber(); - printf("[%.3f s] Creating index \"%s\"\n", get_time_diff(), index_type_.c_str()); - index_ = knowhere::IndexFactory::Instance().Create(index_type_, version); + auto index = knowhere::IndexFactory::Instance().Create(index_type, version); try { - printf("[%.3f s] Reading index file: %s\n", get_time_diff(), index_file_name.c_str()); - read_index(index_.value(), index_file_name, conf); + printf("[%.3f s] Reading %sindex file: %s\n", + get_time_diff(), + additional_name_s.c_str(), + index_file_name.c_str()); + + read_index(index.value(), index_file_name, conf); } catch (...) 
{ - printf("[%.3f s] Building all on %d vectors\n", get_time_diff(), nb_); - auto ds_ptr = knowhere::GenDataSet(nb_, dim_, xb_); - auto base = knowhere::ConvertToDataTypeIfNeeded(ds_ptr); - index_.value().Build(base, conf); + printf("[%.3f s] Building %sindex all on %ld vectors\n", + get_time_diff(), + additional_name_s.c_str(), + default_ds_ptr->GetRows()); + + auto base = knowhere::ConvertToDataTypeIfNeeded(default_ds_ptr); + index.value().Build(base, conf); - printf("[%.3f s] Writing index file: %s\n", get_time_diff(), index_file_name.c_str()); - write_index(index_.value(), index_file_name, conf); + printf("[%.3f s] Writing %sindex file: %s\n", + get_time_diff(), + additional_name_s.c_str(), + index_file_name.c_str()); + + write_index(index.value(), index_file_name, conf); } - return index_.value(); + + return index.value(); + } + + template + knowhere::Index + create_index(const std::string& index_file_name, const knowhere::Json& conf) { + auto idx = this->create_index( + index_type_, + index_file_name, + knowhere::GenDataSet(nb_, dim_, xb_), + conf + ); + index_ = idx; + return idx; } knowhere::Index create_golden_index(const knowhere::Json& conf) { - auto version = knowhere::Version::GetCurrentVersion().VersionNumber(); golden_index_type_ = knowhere::IndexEnum::INDEX_FAISS_IDMAP; - std::string golden_index_file_name = ann_test_name_ + "_" + golden_index_type_ + "_GOLDEN" + ".index"; - printf("[%.3f s] Creating golden index \"%s\"\n", get_time_diff(), golden_index_type_.c_str()); - golden_index_ = knowhere::IndexFactory::Instance().Create(golden_index_type_, version); - - try { - printf("[%.3f s] Reading golden index file: %s\n", get_time_diff(), golden_index_file_name.c_str()); - read_index(golden_index_.value(), golden_index_file_name, conf); - } catch (...) 
{ - printf("[%.3f s] Building golden index on %d vectors\n", get_time_diff(), nb_); - knowhere::DataSetPtr ds_ptr = knowhere::GenDataSet(nb_, dim_, xb_); - golden_index_.value().Build(ds_ptr, conf); - printf("[%.3f s] Writing golden index file: %s\n", get_time_diff(), golden_index_file_name.c_str()); - write_index(golden_index_.value(), golden_index_file_name, conf); - } - return golden_index_.value(); + auto idx = this->create_index( + golden_index_type_, + golden_index_file_name, + knowhere::GenDataSet(nb_, dim_, xb_), + conf, + "golden " + ); + golden_index_ = idx; + return idx; } void diff --git a/benchmark/utils.h b/benchmark/utils.h index bb6cb066a..a9bc5948d 100644 --- a/benchmark/utils.h +++ b/benchmark/utils.h @@ -11,6 +11,9 @@ #pragma once +#include +#include +#include #include #include #include diff --git a/src/index/hnsw/faiss_hnsw_config.h b/src/index/hnsw/faiss_hnsw_config.h index 77116ce9a..b88dc1ea0 100644 --- a/src/index/hnsw/faiss_hnsw_config.h +++ b/src/index/hnsw/faiss_hnsw_config.h @@ -118,7 +118,7 @@ class FaissHnswConfig : public BaseConfig { bool WhetherAcceptableRefineType(const std::string& refine_type) { // 'flat' is identical to 'fp32' std::vector allowed_list = { - "SQ8", "FP16", "BF16", "FP32", "FLAT"}; + "SQ6", "SQ8", "FP16", "BF16", "FP32", "FLAT"}; // todo: tolower() From ea4fc9442f69a5f22836ade245e6d6a668643bac Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Mon, 5 Aug 2024 19:28:40 -0400 Subject: [PATCH 07/21] tweaks for ScalarQuantizer Signed-off-by: Alexandr Guzhva --- .../faiss/faiss/impl/ScalarQuantizerCodec.h | 45 ++++++------ .../faiss/impl/ScalarQuantizerCodec_avx.h | 60 ++++++++-------- .../faiss/impl/ScalarQuantizerCodec_avx512.h | 72 +++++++++---------- .../faiss/impl/ScalarQuantizerCodec_neon.h | 60 ++++++++-------- 4 files changed, 121 insertions(+), 116 deletions(-) diff --git a/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec.h b/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec.h index 6a20a0ca8..48eea0c45 100644 --- a/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec.h +++ b/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec.h @@ -118,11 +118,16 @@ struct Codec6bit { * through a codec *******************************************************************/ -template +enum class QuantizerTemplateScaling { + UNIFORM = 0, + NON_UNIFORM = 1 +}; + +template struct QuantizerTemplate {}; template -struct QuantizerTemplate : SQuantizer { +struct QuantizerTemplate : SQuantizer { const size_t d; const float vmin, vdiff; @@ -160,7 +165,7 @@ struct QuantizerTemplate : SQuantizer { }; template -struct QuantizerTemplate : SQuantizer { +struct QuantizerTemplate : SQuantizer { const size_t d; const float *vmin, *vdiff; @@ -330,19 +335,19 @@ SQuantizer* select_quantizer_1( const std::vector& trained) { switch (qtype) { case ScalarQuantizer::QT_8bit: - return new QuantizerTemplate( + return new QuantizerTemplate( d, trained); case ScalarQuantizer::QT_6bit: - return new QuantizerTemplate( + return new QuantizerTemplate( d, trained); case ScalarQuantizer::QT_4bit: - return new QuantizerTemplate( + return new QuantizerTemplate( d, trained); case ScalarQuantizer::QT_8bit_uniform: - return new QuantizerTemplate( + return new QuantizerTemplate( d, trained); case ScalarQuantizer::QT_4bit_uniform: - return new QuantizerTemplate( + return new QuantizerTemplate( d, trained); case ScalarQuantizer::QT_fp16: return new QuantizerFP16(d, trained); @@ -547,31 +552,31 @@ SQDistanceComputer* select_distance_computer( switch (qtype) { case ScalarQuantizer::QT_8bit_uniform: 
return new DCTemplate< - QuantizerTemplate, + QuantizerTemplate, Sim, SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_4bit_uniform: return new DCTemplate< - QuantizerTemplate, + QuantizerTemplate, Sim, SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_8bit: return new DCTemplate< - QuantizerTemplate, + QuantizerTemplate, Sim, SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_6bit: return new DCTemplate< - QuantizerTemplate, + QuantizerTemplate, Sim, SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_4bit: return new DCTemplate< - QuantizerTemplate, + QuantizerTemplate, Sim, SIMDWIDTH>(d, trained); @@ -648,7 +653,7 @@ InvertedListScanner* sel2_InvertedListScanner( } } -template +template InvertedListScanner* sel12_InvertedListScanner( const ScalarQuantizer* sq, const Index* quantizer, @@ -656,7 +661,7 @@ InvertedListScanner* sel12_InvertedListScanner( const IDSelector* sel, bool r) { constexpr int SIMDWIDTH = Similarity::simdwidth; - using QuantizerClass = QuantizerTemplate; + using QuantizerClass = QuantizerTemplate; using DCClass = DCTemplate; return sel2_InvertedListScanner( sq, quantizer, store_pairs, sel, r); @@ -672,19 +677,19 @@ InvertedListScanner* sel1_InvertedListScanner( constexpr int SIMDWIDTH = Similarity::simdwidth; switch (sq->qtype) { case ScalarQuantizer::QT_8bit_uniform: - return sel12_InvertedListScanner( + return sel12_InvertedListScanner( sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_4bit_uniform: - return sel12_InvertedListScanner( + return sel12_InvertedListScanner( sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_8bit: - return sel12_InvertedListScanner( + return sel12_InvertedListScanner( sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_4bit: - return sel12_InvertedListScanner( + return sel12_InvertedListScanner( sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_6bit: - return sel12_InvertedListScanner( + return sel12_InvertedListScanner( sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_fp16: return sel2_InvertedListScanner +template struct QuantizerTemplate_avx {}; template -struct QuantizerTemplate_avx - : public QuantizerTemplate { +struct QuantizerTemplate_avx + : public QuantizerTemplate { QuantizerTemplate_avx(size_t d, const std::vector& trained) - : QuantizerTemplate(d, trained) {} + : QuantizerTemplate(d, trained) {} }; template -struct QuantizerTemplate_avx - : public QuantizerTemplate { +struct QuantizerTemplate_avx + : public QuantizerTemplate { QuantizerTemplate_avx(size_t d, const std::vector& trained) - : QuantizerTemplate(d, trained) {} + : QuantizerTemplate(d, trained) {} FAISS_ALWAYS_INLINE __m256 reconstruct_8_components(const uint8_t* code, int i) const { @@ -143,17 +143,17 @@ struct QuantizerTemplate_avx }; template -struct QuantizerTemplate_avx - : public QuantizerTemplate { +struct QuantizerTemplate_avx + : public QuantizerTemplate { QuantizerTemplate_avx(size_t d, const std::vector& trained) - : QuantizerTemplate(d, trained) {} + : QuantizerTemplate(d, trained) {} }; template -struct QuantizerTemplate_avx - : public QuantizerTemplate { +struct QuantizerTemplate_avx + : public QuantizerTemplate { QuantizerTemplate_avx(size_t d, const std::vector& trained) - : QuantizerTemplate(d, trained) {} + : QuantizerTemplate(d, trained) {} FAISS_ALWAYS_INLINE __m256 reconstruct_8_components(const uint8_t* code, int i) const { @@ -278,19 +278,19 @@ SQuantizer* select_quantizer_1_avx( const std::vector& trained) { switch (qtype) { case QuantizerType::QT_8bit: - return new QuantizerTemplate_avx( 
+ return new QuantizerTemplate_avx( d, trained); case QuantizerType::QT_6bit: - return new QuantizerTemplate_avx( + return new QuantizerTemplate_avx( d, trained); case QuantizerType::QT_4bit: - return new QuantizerTemplate_avx( + return new QuantizerTemplate_avx( d, trained); case QuantizerType::QT_8bit_uniform: - return new QuantizerTemplate_avx( + return new QuantizerTemplate_avx( d, trained); case QuantizerType::QT_4bit_uniform: - return new QuantizerTemplate_avx( + return new QuantizerTemplate_avx( d, trained); case QuantizerType::QT_fp16: return new QuantizerFP16_avx(d, trained); @@ -606,31 +606,31 @@ SQDistanceComputer* select_distance_computer_avx( switch (qtype) { case QuantizerType::QT_8bit_uniform: return new DCTemplate_avx< - QuantizerTemplate_avx, + QuantizerTemplate_avx, Sim, SIMDWIDTH>(d, trained); case QuantizerType::QT_4bit_uniform: return new DCTemplate_avx< - QuantizerTemplate_avx, + QuantizerTemplate_avx, Sim, SIMDWIDTH>(d, trained); case QuantizerType::QT_8bit: return new DCTemplate_avx< - QuantizerTemplate_avx, + QuantizerTemplate_avx, Sim, SIMDWIDTH>(d, trained); case QuantizerType::QT_6bit: return new DCTemplate_avx< - QuantizerTemplate_avx, + QuantizerTemplate_avx, Sim, SIMDWIDTH>(d, trained); case QuantizerType::QT_4bit: return new DCTemplate_avx< - QuantizerTemplate_avx, + QuantizerTemplate_avx, Sim, SIMDWIDTH>(d, trained); @@ -677,7 +677,7 @@ InvertedListScanner* sel2_InvertedListScanner_avx( sq, quantizer, store_pairs, sel, r); } -template +template InvertedListScanner* sel12_InvertedListScanner_avx( const ScalarQuantizer* sq, const Index* quantizer, @@ -685,7 +685,7 @@ InvertedListScanner* sel12_InvertedListScanner_avx( const IDSelector* sel, bool r) { constexpr int SIMDWIDTH = Similarity::simdwidth; - using QuantizerClass = QuantizerTemplate_avx; + using QuantizerClass = QuantizerTemplate_avx; using DCClass = DCTemplate_avx; return sel2_InvertedListScanner_avx( sq, quantizer, store_pairs, sel, r); @@ -704,27 +704,27 @@ InvertedListScanner* sel1_InvertedListScanner_avx( return sel12_InvertedListScanner_avx< Similarity, Codec8bit_avx, - true>(sq, quantizer, store_pairs, sel, r); + QuantizerTemplateScaling::UNIFORM>(sq, quantizer, store_pairs, sel, r); case QuantizerType::QT_4bit_uniform: return sel12_InvertedListScanner_avx< Similarity, Codec4bit_avx, - true>(sq, quantizer, store_pairs, sel, r); + QuantizerTemplateScaling::UNIFORM>(sq, quantizer, store_pairs, sel, r); case QuantizerType::QT_8bit: return sel12_InvertedListScanner_avx< Similarity, Codec8bit_avx, - false>(sq, quantizer, store_pairs, sel, r); + QuantizerTemplateScaling::NON_UNIFORM>(sq, quantizer, store_pairs, sel, r); case QuantizerType::QT_4bit: return sel12_InvertedListScanner_avx< Similarity, Codec4bit_avx, - false>(sq, quantizer, store_pairs, sel, r); + QuantizerTemplateScaling::NON_UNIFORM>(sq, quantizer, store_pairs, sel, r); case QuantizerType::QT_6bit: return sel12_InvertedListScanner_avx< Similarity, Codec6bit_avx, - false>(sq, quantizer, store_pairs, sel, r); + QuantizerTemplateScaling::NON_UNIFORM>(sq, quantizer, store_pairs, sel, r); case QuantizerType::QT_fp16: return sel2_InvertedListScanner_avx, diff --git a/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h b/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h index bfb7f40f9..b90e5f318 100644 --- a/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h +++ b/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h @@ -158,28 +158,28 @@ struct Codec6bit_avx512 : public Codec6bit_avx { * through a codec 
*******************************************************************/ -template +template struct QuantizerTemplate_avx512 {}; template -struct QuantizerTemplate_avx512 - : public QuantizerTemplate_avx { +struct QuantizerTemplate_avx512 + : public QuantizerTemplate_avx { QuantizerTemplate_avx512(size_t d, const std::vector& trained) - : QuantizerTemplate_avx(d, trained) {} + : QuantizerTemplate_avx(d, trained) {} }; template -struct QuantizerTemplate_avx512 - : public QuantizerTemplate_avx { +struct QuantizerTemplate_avx512 + : public QuantizerTemplate_avx { QuantizerTemplate_avx512(size_t d, const std::vector& trained) - : QuantizerTemplate_avx(d, trained) {} + : QuantizerTemplate_avx(d, trained) {} }; template -struct QuantizerTemplate_avx512 - : public QuantizerTemplate_avx { +struct QuantizerTemplate_avx512 + : public QuantizerTemplate_avx { QuantizerTemplate_avx512(size_t d, const std::vector& trained) - : QuantizerTemplate_avx(d, trained) {} + : QuantizerTemplate_avx(d, trained) {} FAISS_ALWAYS_INLINE __m512 reconstruct_16_components(const uint8_t* code, int i) const { @@ -190,24 +190,24 @@ struct QuantizerTemplate_avx512 }; template -struct QuantizerTemplate_avx512 - : public QuantizerTemplate_avx { +struct QuantizerTemplate_avx512 + : public QuantizerTemplate_avx { QuantizerTemplate_avx512(size_t d, const std::vector& trained) - : QuantizerTemplate_avx(d, trained) {} + : QuantizerTemplate_avx(d, trained) {} }; template -struct QuantizerTemplate_avx512 - : public QuantizerTemplate_avx { +struct QuantizerTemplate_avx512 + : public QuantizerTemplate_avx { QuantizerTemplate_avx512(size_t d, const std::vector& trained) - : QuantizerTemplate_avx(d, trained) {} + : QuantizerTemplate_avx(d, trained) {} }; template -struct QuantizerTemplate_avx512 - : public QuantizerTemplate_avx { +struct QuantizerTemplate_avx512 + : public QuantizerTemplate_avx { QuantizerTemplate_avx512(size_t d, const std::vector& trained) - : QuantizerTemplate_avx(d, trained) {} + : QuantizerTemplate_avx(d, trained) {} FAISS_ALWAYS_INLINE __m512 reconstruct_16_components(const uint8_t* code, int i) const { @@ -365,27 +365,27 @@ SQuantizer* select_quantizer_1_avx512( case QuantizerType::QT_8bit: return new QuantizerTemplate_avx512< Codec8bit_avx512, - false, + QuantizerTemplateScaling::NON_UNIFORM, SIMDWIDTH>(d, trained); case QuantizerType::QT_6bit: return new QuantizerTemplate_avx512< Codec6bit_avx512, - false, + QuantizerTemplateScaling::NON_UNIFORM, SIMDWIDTH>(d, trained); case QuantizerType::QT_4bit: return new QuantizerTemplate_avx512< Codec4bit_avx512, - false, + QuantizerTemplateScaling::NON_UNIFORM, SIMDWIDTH>(d, trained); case QuantizerType::QT_8bit_uniform: return new QuantizerTemplate_avx512< Codec8bit_avx512, - true, + QuantizerTemplateScaling::UNIFORM, SIMDWIDTH>(d, trained); case QuantizerType::QT_4bit_uniform: return new QuantizerTemplate_avx512< Codec4bit_avx512, - true, + QuantizerTemplateScaling::UNIFORM, SIMDWIDTH>(d, trained); case QuantizerType::QT_fp16: return new QuantizerFP16_avx512(d, trained); @@ -734,13 +734,13 @@ SQDistanceComputer* select_distance_computer_avx512( switch (qtype) { case QuantizerType::QT_8bit_uniform: return new DCTemplate_avx512< - QuantizerTemplate_avx512, + QuantizerTemplate_avx512, Sim, SIMDWIDTH>(d, trained); case QuantizerType::QT_4bit_uniform: return new DCTemplate_avx512< - QuantizerTemplate_avx512, + QuantizerTemplate_avx512, Sim, SIMDWIDTH>(d, trained); @@ -748,7 +748,7 @@ SQDistanceComputer* select_distance_computer_avx512( return new DCTemplate_avx512< 
QuantizerTemplate_avx512< Codec8bit_avx512, - false, + QuantizerTemplateScaling::NON_UNIFORM, SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained); @@ -757,7 +757,7 @@ SQDistanceComputer* select_distance_computer_avx512( return new DCTemplate_avx512< QuantizerTemplate_avx512< Codec6bit_avx512, - false, + QuantizerTemplateScaling::NON_UNIFORM, SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained); @@ -766,7 +766,7 @@ SQDistanceComputer* select_distance_computer_avx512( return new DCTemplate_avx512< QuantizerTemplate_avx512< Codec4bit_avx512, - false, + QuantizerTemplateScaling::NON_UNIFORM, SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained); @@ -815,7 +815,7 @@ InvertedListScanner* sel2_InvertedListScanner_avx512( sq, quantizer, store_pairs, sel, r); } -template +template InvertedListScanner* sel12_InvertedListScanner_avx512( const ScalarQuantizer* sq, const Index* quantizer, @@ -823,7 +823,7 @@ InvertedListScanner* sel12_InvertedListScanner_avx512( const IDSelector* sel, bool r) { constexpr int SIMDWIDTH = Similarity::simdwidth; - using QuantizerClass = QuantizerTemplate_avx512; + using QuantizerClass = QuantizerTemplate_avx512; using DCClass = DCTemplate_avx512; return sel2_InvertedListScanner_avx512( sq, quantizer, store_pairs, sel, r); @@ -842,27 +842,27 @@ InvertedListScanner* sel1_InvertedListScanner_avx512( return sel12_InvertedListScanner_avx512< Similarity, Codec8bit_avx512, - true>(sq, quantizer, store_pairs, sel, r); + QuantizerTemplateScaling::UNIFORM>(sq, quantizer, store_pairs, sel, r); case QuantizerType::QT_4bit_uniform: return sel12_InvertedListScanner_avx512< Similarity, Codec4bit_avx512, - true>(sq, quantizer, store_pairs, sel, r); + QuantizerTemplateScaling::UNIFORM>(sq, quantizer, store_pairs, sel, r); case QuantizerType::QT_8bit: return sel12_InvertedListScanner_avx512< Similarity, Codec8bit_avx512, - false>(sq, quantizer, store_pairs, sel, r); + QuantizerTemplateScaling::NON_UNIFORM>(sq, quantizer, store_pairs, sel, r); case QuantizerType::QT_4bit: return sel12_InvertedListScanner_avx512< Similarity, Codec4bit_avx512, - false>(sq, quantizer, store_pairs, sel, r); + QuantizerTemplateScaling::NON_UNIFORM>(sq, quantizer, store_pairs, sel, r); case QuantizerType::QT_6bit: return sel12_InvertedListScanner_avx512< Similarity, Codec6bit_avx512, - false>(sq, quantizer, store_pairs, sel, r); + QuantizerTemplateScaling::NON_UNIFORM>(sq, quantizer, store_pairs, sel, r); case QuantizerType::QT_fp16: return sel2_InvertedListScanner_avx512, diff --git a/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_neon.h b/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_neon.h index 25cc36503..cef65d8b5 100644 --- a/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_neon.h +++ b/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_neon.h @@ -74,21 +74,21 @@ struct Codec6bit_neon : public Codec6bit { * through a codec *******************************************************************/ -template +template struct QuantizerTemplate_neon {}; template -struct QuantizerTemplate_neon - : public QuantizerTemplate { +struct QuantizerTemplate_neon + : public QuantizerTemplate { QuantizerTemplate_neon(size_t d, const std::vector& trained) - : QuantizerTemplate(d, trained) {} + : QuantizerTemplate(d, trained) {} }; template -struct QuantizerTemplate_neon - : public QuantizerTemplate { +struct QuantizerTemplate_neon + : public QuantizerTemplate { QuantizerTemplate_neon(size_t d, const std::vector& trained) - : QuantizerTemplate(d, trained) {} + : QuantizerTemplate(d, trained) {} FAISS_ALWAYS_INLINE float32x4x2_t reconstruct_8_components(const 
uint8_t* code, int i) const { @@ -107,17 +107,17 @@ struct QuantizerTemplate_neon }; template -struct QuantizerTemplate_neon - : public QuantizerTemplate { +struct QuantizerTemplate_neon + : public QuantizerTemplate { QuantizerTemplate_neon(size_t d, const std::vector& trained) - : QuantizerTemplate(d, trained) {} + : QuantizerTemplate(d, trained) {} }; template -struct QuantizerTemplate_neon - : public QuantizerTemplate { +struct QuantizerTemplate_neon + : public QuantizerTemplate { QuantizerTemplate_neon(size_t d, const std::vector& trained) - : QuantizerTemplate(d, trained) {} + : QuantizerTemplate(d, trained) {} FAISS_ALWAYS_INLINE float32x4x2_t reconstruct_8_components(const uint8_t* code, int i) const { @@ -258,19 +258,19 @@ SQuantizer* select_quantizer_1_neon( const std::vector& trained) { switch (qtype) { case QuantizerType::QT_8bit: - return new QuantizerTemplate_neon( + return new QuantizerTemplate_neon( d, trained); case QuantizerType::QT_6bit: - return new QuantizerTemplate_neon( + return new QuantizerTemplate_neon( d, trained); case QuantizerType::QT_4bit: - return new QuantizerTemplate_neon( + return new QuantizerTemplate_neon( d, trained); case QuantizerType::QT_8bit_uniform: - return new QuantizerTemplate_neon( + return new QuantizerTemplate_neon( d, trained); case QuantizerType::QT_4bit_uniform: - return new QuantizerTemplate_neon( + return new QuantizerTemplate_neon( d, trained); case QuantizerType::QT_fp16: return new QuantizerFP16_neon(d, trained); @@ -588,31 +588,31 @@ SQDistanceComputer* select_distance_computer_neon( switch (qtype) { case QuantizerType::QT_8bit_uniform: return new DCTemplate_neon< - QuantizerTemplate_neon, + QuantizerTemplate_neon, Sim, SIMDWIDTH>(d, trained); case QuantizerType::QT_4bit_uniform: return new DCTemplate_neon< - QuantizerTemplate_neon, + QuantizerTemplate_neon, Sim, SIMDWIDTH>(d, trained); case QuantizerType::QT_8bit: return new DCTemplate_neon< - QuantizerTemplate_neon, + QuantizerTemplate_neon, Sim, SIMDWIDTH>(d, trained); case QuantizerType::QT_6bit: return new DCTemplate_neon< - QuantizerTemplate_neon, + QuantizerTemplate_neon, Sim, SIMDWIDTH>(d, trained); case QuantizerType::QT_4bit: return new DCTemplate_neon< - QuantizerTemplate_neon, + QuantizerTemplate_neon, Sim, SIMDWIDTH>(d, trained); @@ -659,7 +659,7 @@ InvertedListScanner* sel2_InvertedListScanner_neon( sq, quantizer, store_pairs, sel, r); } -template +template InvertedListScanner* sel12_InvertedListScanner_neon( const ScalarQuantizer* sq, const Index* quantizer, @@ -667,7 +667,7 @@ InvertedListScanner* sel12_InvertedListScanner_neon( const IDSelector* sel, bool r) { constexpr int SIMDWIDTH = Similarity::simdwidth; - using QuantizerClass = QuantizerTemplate_neon; + using QuantizerClass = QuantizerTemplate_neon; using DCClass = DCTemplate_neon; return sel2_InvertedListScanner_neon( sq, quantizer, store_pairs, sel, r); @@ -686,27 +686,27 @@ InvertedListScanner* sel1_InvertedListScanner_neon( return sel12_InvertedListScanner_neon< Similarity, Codec8bit_neon, - true>(sq, quantizer, store_pairs, sel, r); + QuantizerTemplateScaling::UNIFORM>(sq, quantizer, store_pairs, sel, r); case QuantizerType::QT_4bit_uniform: return sel12_InvertedListScanner_neon< Similarity, Codec4bit_neon, - true>(sq, quantizer, store_pairs, sel, r); + QuantizerTemplateScaling::UNIFORM>(sq, quantizer, store_pairs, sel, r); case QuantizerType::QT_8bit: return sel12_InvertedListScanner_neon< Similarity, Codec8bit_neon, - false>(sq, quantizer, store_pairs, sel, r); + QuantizerTemplateScaling::NON_UNIFORM>(sq, 
quantizer, store_pairs, sel, r); case QuantizerType::QT_4bit: return sel12_InvertedListScanner_neon< Similarity, Codec4bit_neon, - false>(sq, quantizer, store_pairs, sel, r); + QuantizerTemplateScaling::NON_UNIFORM>(sq, quantizer, store_pairs, sel, r); case QuantizerType::QT_6bit: return sel12_InvertedListScanner_neon< Similarity, Codec6bit_neon, - false>(sq, quantizer, store_pairs, sel, r); + QuantizerTemplateScaling::NON_UNIFORM>(sq, quantizer, store_pairs, sel, r); case QuantizerType::QT_fp16: return sel2_InvertedListScanner_neon, From 44bbc8b8d7d7067548c824bca5dd219e708f6e1c Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 6 Aug 2024 10:19:47 -0400 Subject: [PATCH 08/21] add fp16 / bf16 support for INDEX_FAISS_FLAT Signed-off-by: Alexandr Guzhva --- benchmark/hdf5/benchmark_faiss_hnsw.cpp | 44 ++++- include/knowhere/index/index_table.h | 2 + include/knowhere/utils.h | 57 +++++-- src/index/hnsw/faiss_hnsw.cc | 212 +++++++++++++++++++++--- 4 files changed, 273 insertions(+), 42 deletions(-) diff --git a/benchmark/hdf5/benchmark_faiss_hnsw.cpp b/benchmark/hdf5/benchmark_faiss_hnsw.cpp index 1812b96cc..625e54f36 100644 --- a/benchmark/hdf5/benchmark_faiss_hnsw.cpp +++ b/benchmark/hdf5/benchmark_faiss_hnsw.cpp @@ -74,21 +74,23 @@ class Benchmark_Faiss_Hnsw : public Benchmark_knowhere, public ::testing::Test { conf ); - auto result = index.Search(query_ds_ptr, conf, nullptr); - auto result_loaded = index_loaded.Search(query_ds_ptr, conf, nullptr); + auto query_t_ds_ptr = knowhere::ConvertToDataTypeIfNeeded(query_ds_ptr); + + auto result = index.Search(query_t_ds_ptr, conf, nullptr); + auto result_loaded = index_loaded.Search(query_t_ds_ptr, conf, nullptr); // calc recall auto recall = this->CalcRecall( golden_result->GetIds(), result.value()->GetIds(), - query_ds_ptr->GetRows(), + query_t_ds_ptr->GetRows(), conf[knowhere::meta::TOPK].get() ); auto recall_loaded = this->CalcRecall( golden_result->GetIds(), result_loaded.value()->GetIds(), - query_ds_ptr->GetRows(), + query_t_ds_ptr->GetRows(), conf[knowhere::meta::TOPK].get() ); @@ -208,7 +210,7 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWFLAT) { auto golden_result = golden_index.Search(query_ds_ptr, conf, nullptr); - // + // fp32 printf("\n"); printf("Processing HNSW,Flat fp32 for %s distance, dim=%d, nrows=%d\n", DISTANCE_TYPES[distance_type].c_str(), @@ -222,6 +224,36 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWFLAT) { golden_result.value(), params, conf); + + // fp16 + printf("\n"); + printf("Processing HNSW,Flat fp16 for %s distance, dim=%d, nrows=%d\n", + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + // test a candidate + test_hnsw( + default_ds_ptr, + query_ds_ptr, + golden_result.value(), + params, + conf); + + // bf32 + printf("\n"); + printf("Processing HNSW,Flat bf16 for %s distance, dim=%d, nrows=%d\n", + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + // test a candidate + test_hnsw( + default_ds_ptr, + query_ds_ptr, + golden_result.value(), + params, + conf); } } } @@ -506,4 +538,4 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWPRQ) { } } } -} \ No newline at end of file +} diff --git a/include/knowhere/index/index_table.h b/include/knowhere/index/index_table.h index cda09e960..6b5fa4d97 100644 --- a/include/knowhere/index/index_table.h +++ b/include/knowhere/index/index_table.h @@ -68,6 +68,8 @@ static std::set> legal_knowhere_index = { {IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_BFLOAT16}, // faiss hnsw {IndexEnum::INDEX_FAISS_HNSW_FLAT, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_FAISS_HNSW_FLAT, 
VecType::VECTOR_FLOAT16}, + {IndexEnum::INDEX_FAISS_HNSW_FLAT, VecType::VECTOR_BFLOAT16}, {IndexEnum::INDEX_FAISS_HNSW_SQ, VecType::VECTOR_FLOAT}, diff --git a/include/knowhere/utils.h b/include/knowhere/utils.h index 1e9775fbe..f1de7adfd 100644 --- a/include/knowhere/utils.h +++ b/include/knowhere/utils.h @@ -109,18 +109,34 @@ GetKey(const std::string& name) { template inline DataSetPtr -data_type_conversion(const DataSet& src) { +data_type_conversion( + const DataSet& src, + const std::optional start = std::nullopt, + const std::optional count = std::nullopt +) { auto dim = src.GetDim(); auto rows = src.GetRows(); - auto des_data = new OutType[dim * rows]; - auto src_data = (InType*)src.GetTensor(); - for (auto i = 0; i < dim * rows; i++) { - des_data[i] = (OutType)src_data[i]; + // check the acceptable range + int64_t start_row = start.value_or(0); + if (start_row < 0 || start_row >= rows) { + return nullptr; + } + + int64_t count_rows = count.value_or(rows - start_row); + if (count_rows < 0 || start_row + count_rows > rows) { + return nullptr; + } + + // map + auto* des_data = new OutType[dim * count_rows]; + auto* src_data = (const InType*)src.GetTensor(); + for (auto i = 0; i < dim * count_rows; i++) { + des_data[i] = (OutType)src_data[i + start_row * dim]; } auto des = std::make_shared(); - des->SetRows(rows); + des->SetRows(count_rows); des->SetDim(dim); des->SetTensor(des_data); des->SetIsOwner(true); @@ -128,25 +144,36 @@ data_type_conversion(const DataSet& src) { } // Convert DataSet from DataType to float +// * no start, no count, float -> returns the source without cloning +// * no start, no count, no float -> returns a clone with a different type +// * start, no count -> returns a clone that starts from a given row 'start' +// * no start, count -> returns a clone that starts from a row 0 and has 'count' rows +// * start, count -> returns a clone that start from a given row 'start' and has 'count' rows +// * invalid start, count values -> returns nullptr template inline DataSetPtr -ConvertFromDataTypeIfNeeded(const DataSetPtr ds) { +ConvertFromDataTypeIfNeeded(const DataSetPtr& ds, const std::optional start = std::nullopt, const std::optional count = std::nullopt) { if constexpr (std::is_same_v::type>) { - return ds; - } else { - return data_type_conversion::type>(*ds); + if (!start.has_value() && !count.has_value()) { + return ds; + } } + + return data_type_conversion::type>(*ds, start, count); } + // Convert DataSet from float to DataType template inline DataSetPtr -ConvertToDataTypeIfNeeded(const DataSetPtr ds) { +ConvertToDataTypeIfNeeded(const DataSetPtr& ds, const std::optional start = std::nullopt, const std::optional count = std::nullopt) { if constexpr (std::is_same_v::type>) { - return ds; - } else { - return data_type_conversion::type, DataType>(*ds); - } + if (!start.has_value() && !count.has_value()) { + return ds; + } + } + + return data_type_conversion::type, DataType>(*ds, start, count); } template diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 3e48e6411..eab4be244 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -279,11 +279,70 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode { } }; +// +enum class DataFormatEnum { + fp32, + fp16, + bf16 +}; + +template +struct DataType2EnumHelper {}; + +template<> struct DataType2EnumHelper { + static constexpr DataFormatEnum value = DataFormatEnum::fp32; +}; +template<> struct DataType2EnumHelper { + static constexpr DataFormatEnum value = 
DataFormatEnum::fp16; +}; +template<> struct DataType2EnumHelper { + static constexpr DataFormatEnum value = DataFormatEnum::bf16; +}; + +template +static constexpr DataFormatEnum datatype_v = DataType2EnumHelper::value; + + + +// +namespace { + +// +bool convert_rows( + const void* const __restrict src_in, + float* const __restrict dst, + const DataFormatEnum data_format, + const size_t start_row, + const size_t nrows, + const size_t dim +) { + if (data_format == DataFormatEnum::fp16) { + const knowhere::fp16* const src = reinterpret_cast(src_in); + for (size_t i = 0; i < nrows * dim; i++) { + dst[i] = (float)(src[i + start_row * dim]); + } + + return true; + } else if (data_format == DataFormatEnum::bf16) { + const knowhere::bf16* const src = reinterpret_cast(src_in); + for (size_t i = 0; i < nrows * dim; i++) { + dst[i] = (float)(src[i + start_row * dim]); + } + + return true; + } else { + // unknown + return false; + } +} + +} + // class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { public: - BaseFaissRegularIndexHNSWNode(const int32_t& version, const Object& object) - : BaseFaissRegularIndexNode(version, object) { + BaseFaissRegularIndexHNSWNode(const int32_t& version, const Object& object, DataFormatEnum data_format_in) + : BaseFaissRegularIndexNode(version, object), data_format{data_format_in} { } bool @@ -350,7 +409,18 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { ThreadPool::ScopedOmpSetter setter(1); // set up a query - const float* cur_query = (const float*)data + idx * dim; + // const float* cur_query = (const float*)data + idx * dim; + + const float* cur_query = nullptr; + + std::vector cur_query_tmp(dim); + if (data_format == DataFormatEnum::fp32) { + cur_query = (const float*)data + idx * dim; + } else { + convert_rows(data, cur_query_tmp.data(), data_format, idx, 1, dim); + cur_query = cur_query_tmp.data(); + } + // set up local results faiss::idx_t* const __restrict local_ids = ids.get() + k * idx; @@ -506,6 +576,8 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { } protected: + DataFormatEnum data_format; + // Decides whether a brute force should be used instead of a regular HNSW search. // This may be applicable in case of very large topk values or // extremely high filtering levels. 
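The predicate that the comment above refers to is not visible in this hunk. As a rough, hedged sketch of the kind of check being described (the helper name, signature, and thresholds below are purely illustrative assumptions, not the actual knowhere implementation):

    // hypothetical helper, for illustration only
    static bool prefer_brute_force(int64_t k, int64_t nrows, int64_t n_filtered_out) {
        // a topk close to the dataset size, or a bitset that filters out almost
        // every row, makes HNSW graph traversal degenerate, so an exact scan
        // over the remaining rows may be cheaper
        const float filtered_ratio =
            (nrows == 0) ? 0.0f : static_cast<float>(n_filtered_out) / static_cast<float>(nrows);
        return (k >= nrows) || (filtered_ratio > 0.98f);
    }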
@@ -560,11 +632,10 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { }; // -template class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode { public: - BaseFaissRegularIndexHNSWFlatNode(const int32_t& version, const Object& object) - : BaseFaissRegularIndexHNSWNode(version, object) { + BaseFaissRegularIndexHNSWFlatNode(const int32_t& version, const Object& object, DataFormatEnum data_format) + : BaseFaissRegularIndexHNSWNode(version, object, data_format) { } bool @@ -611,9 +682,31 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode { std::unique_ptr hnsw_index; if (is_cosine) { - hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); + if (data_format == DataFormatEnum::fp32) { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); + } else if (data_format == DataFormatEnum::fp16) { + hnsw_index = std::make_unique( + dim, faiss::ScalarQuantizer::QT_fp16, hnsw_cfg.M.value()); + } else if (data_format == DataFormatEnum::bf16) { + hnsw_index = std::make_unique( + dim, faiss::ScalarQuantizer::QT_bf16, hnsw_cfg.M.value()); + } else { + LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value(); + return Status::invalid_metric_type; + } } else { - hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); + if (data_format == DataFormatEnum::fp32) { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); + } else if (data_format == DataFormatEnum::fp16) { + hnsw_index = std::make_unique( + dim, faiss::ScalarQuantizer::QT_fp16, hnsw_cfg.M.value(), metric.value()); + } else if (data_format == DataFormatEnum::bf16) { + hnsw_index = std::make_unique( + dim, faiss::ScalarQuantizer::QT_bf16, hnsw_cfg.M.value(), metric.value()); + } else { + LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value(); + return Status::invalid_metric_type; + } } hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); @@ -621,12 +714,68 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode { // train LOG_KNOWHERE_INFO_ << "Training HNSW Index"; + // this function does nothing for the given parameters and indices. + // as a result, I'm just keeping it to have is_trained set to true. + // WARNING: this may cause problems if ->train() performs some action + // based on the data in the future. Otherwise, data needs to be + // converted into float*. 
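// Illustrative sketch only (an editor's note, hedged): if ->train() ever starts
// reading the vectors, the non-fp32 path would first have to widen the data to
// float, e.g. by reusing convert_rows() in fixed-size chunks, the same way
// AddInternal() below does:
//
//   const int64_t chunk_rows = 65536;   // same chunking idea as AddInternal()
//   std::vector<float> tmp(chunk_rows * dim);
//   for (int64_t row = 0; row < rows; row += chunk_rows) {
//       const int64_t n = std::min(rows, row + chunk_rows) - row;
//       convert_rows(data, tmp.data(), data_format, row, n, dim);
//       // ... feed the n converted rows in tmp.data() to the training routine ...
//   }
//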
hnsw_index->train(rows, (const float*)data); // done index = std::move(hnsw_index); return Status::success; } + + Status + AddInternal(const DataSetPtr dataset, const Config&) override { + if (index == nullptr) { + LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; + return Status::empty_index; + } + + const auto* data = dataset->GetTensor(); + auto rows = dataset->GetRows(); + auto dim = dataset->GetDim(); + try { + LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; + + if (data_format == DataFormatEnum::fp32) { + // add as is + index->add(rows, reinterpret_cast(data)); + } else { + // convert data into float in pieces and add to the index + constexpr int64_t n_tmp_rows = 65536; + std::vector tmp(n_tmp_rows * dim); + + for (int64_t irow = 0; irow < rows; irow += n_tmp_rows) { + const int64_t start_row = irow; + const int64_t end_row = std::min(rows, start_row + n_tmp_rows); + const int64_t count_rows = end_row - start_row; + + if (!convert_rows(data, tmp.data(), data_format, start_row, count_rows, dim)) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + + // add + index->add(count_rows, tmp.data()); + } + } + } catch (const std::exception& e) { + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); + return Status::faiss_inner_error; + } + + return Status::success; + } +}; + +template +class BaseFaissRegularIndexHNSWFlatNodeTemplate : public BaseFaissRegularIndexHNSWFlatNode { + public: + BaseFaissRegularIndexHNSWFlatNodeTemplate(const int32_t& version, const Object& object) + : BaseFaissRegularIndexHNSWFlatNode(version, object, datatype_v) { + } }; @@ -736,11 +885,10 @@ pick_refine_index(const std::optional& refine_type, std::unique_ptr } // namespace // -template class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { public: - BaseFaissRegularIndexHNSWSQNode(const int32_t& version, const Object& object) : - BaseFaissRegularIndexHNSWNode(version, object) {} + BaseFaissRegularIndexHNSWSQNode(const int32_t& version, const Object& object, DataFormatEnum data_format) : + BaseFaissRegularIndexHNSWNode(version, object, data_format) {} std::unique_ptr CreateConfig() const override { @@ -820,13 +968,19 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { } }; +template +class BaseFaissRegularIndexHNSWSQNodeTemplate : public BaseFaissRegularIndexHNSWSQNode { + public: + BaseFaissRegularIndexHNSWSQNodeTemplate(const int32_t& version, const Object& object) : + BaseFaissRegularIndexHNSWSQNode(version, object, datatype_v) {} +}; + // this index trains PQ and HNSW+FLAT separately, then constructs HNSW+PQ -template class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { public: - BaseFaissRegularIndexHNSWPQNode(const int32_t& version, const Object& object) : - BaseFaissRegularIndexHNSWNode(version, object) {} + BaseFaissRegularIndexHNSWPQNode(const int32_t& version, const Object& object, DataFormatEnum data_format) : + BaseFaissRegularIndexHNSWNode(version, object, data_format) {} std::unique_ptr CreateConfig() const override { @@ -990,13 +1144,19 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { } }; +template +class BaseFaissRegularIndexHNSWPQNodeTemplate : public BaseFaissRegularIndexHNSWPQNode { + public: + BaseFaissRegularIndexHNSWPQNodeTemplate(const int32_t& version, const Object& object) : + BaseFaissRegularIndexHNSWPQNode(version, object, datatype_v) {} +}; + // this index trains PRQ and HNSW+FLAT separately, then 
constructs HNSW+PRQ -template class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { public: - BaseFaissRegularIndexHNSWPRQNode(const int32_t& version, const Object& object) : - BaseFaissRegularIndexHNSWNode(version, object) {} + BaseFaissRegularIndexHNSWPRQNode(const int32_t& version, const Object& object, DataFormatEnum data_format) : + BaseFaissRegularIndexHNSWNode(version, object, data_format) {} std::unique_ptr CreateConfig() const override { @@ -1165,13 +1325,23 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { } }; +template +class BaseFaissRegularIndexHNSWPRQNodeTemplate : public BaseFaissRegularIndexHNSWPRQNode { + public: + BaseFaissRegularIndexHNSWPRQNodeTemplate(const int32_t& version, const Object& object) : + BaseFaissRegularIndexHNSWPRQNode(version, object, datatype_v) {} +}; -KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNode, fp32); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_SQ, BaseFaissRegularIndexHNSWSQNode, fp32); +// +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNodeTemplate, fp32); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNodeTemplate, fp16); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNodeTemplate, bf16); + +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, fp32); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PQ, BaseFaissRegularIndexHNSWPQNode, fp32); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, fp32); -KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNode, fp32); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, fp32); } // namespace knowhere From 219b1494bc9b55add2d1505664784c74ab7c5e65 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 6 Aug 2024 10:23:24 -0400 Subject: [PATCH 09/21] fix clang-format Signed-off-by: Alexandr Guzhva --- benchmark/hdf5/benchmark_faiss_hnsw.cpp | 32 +++++++-------- benchmark/hdf5/benchmark_knowhere.h | 26 ++++++------- include/knowhere/index/index_table.h | 6 +-- include/knowhere/utils.h | 6 +-- src/index/hnsw/faiss_hnsw.cc | 52 ++++++++++++------------- src/index/hnsw/faiss_hnsw_config.h | 4 +- 6 files changed, 63 insertions(+), 63 deletions(-) diff --git a/benchmark/hdf5/benchmark_faiss_hnsw.cpp b/benchmark/hdf5/benchmark_faiss_hnsw.cpp index 625e54f36..fd8bf864b 100644 --- a/benchmark/hdf5/benchmark_faiss_hnsw.cpp +++ b/benchmark/hdf5/benchmark_faiss_hnsw.cpp @@ -81,15 +81,15 @@ class Benchmark_Faiss_Hnsw : public Benchmark_knowhere, public ::testing::Test { // calc recall auto recall = this->CalcRecall( - golden_result->GetIds(), - result.value()->GetIds(), + golden_result->GetIds(), + result.value()->GetIds(), query_t_ds_ptr->GetRows(), conf[knowhere::meta::TOPK].get() ); auto recall_loaded = this->CalcRecall( - golden_result->GetIds(), - result_loaded.value()->GetIds(), + golden_result->GetIds(), + result_loaded.value()->GetIds(), query_t_ds_ptr->GetRows(), conf[knowhere::meta::TOPK].get() ); @@ -106,7 +106,7 @@ class Benchmark_Faiss_Hnsw : public Benchmark_knowhere, public ::testing::Test { T0_ = elapsed(); set_ann_test_name("faiss_hnsw"); - knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2); + knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2); cfg_[knowhere::indexparam::HNSW_M] = 16; 
cfg_[knowhere::indexparam::EFCONSTRUCTION] = 96; @@ -119,7 +119,7 @@ class Benchmark_Faiss_Hnsw : public Benchmark_knowhere, public ::testing::Test { void CreateGoldenIndices() { const std::string golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP; - + uint64_t rng_seed = 1; for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) { for (const int32_t dim : DIMS) { @@ -201,7 +201,7 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWFLAT) { ann_test_name_, golden_index_type, params); auto golden_index = create_index( - golden_index_type, + golden_index_type, golden_index_file_name, default_ds_ptr, conf, @@ -289,7 +289,7 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWSQ) { ann_test_name_, golden_index_type, params); auto golden_index = create_index( - golden_index_type, + golden_index_type, golden_index_file_name, default_ds_ptr, conf, @@ -317,12 +317,12 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWSQ) { // test refines const auto& allowed_refs = SQ_ALLOWED_REFINES[SQ_TYPES[sq_type]]; for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); allowed_ref_idx++) { - auto conf_refine = conf; + auto conf_refine = conf; conf_refine["refine"] = true; conf_refine["refine_k"] = 1.5; conf_refine["refine_type"] = allowed_refs[allowed_ref_idx]; - std::vector params_refine = + std::vector params_refine = {(int)distance_type, dim, nb, (int)sq_type, (int)allowed_ref_idx}; // @@ -382,7 +382,7 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWPQ) { ann_test_name_, golden_index_type, params); auto golden_index = create_index( - golden_index_type, + golden_index_type, golden_index_file_name, default_ds_ptr, conf, @@ -410,12 +410,12 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWPQ) { // test refines for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES.size(); allowed_ref_idx++) { - auto conf_refine = conf; + auto conf_refine = conf; conf_refine["refine"] = true; conf_refine["refine_k"] = 1.5; conf_refine["refine_type"] = PQ_ALLOWED_REFINES[allowed_ref_idx]; - std::vector params_refine = + std::vector params_refine = {(int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; // @@ -478,7 +478,7 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWPRQ) { ann_test_name_, golden_index_type, params); auto golden_index = create_index( - golden_index_type, + golden_index_type, golden_index_file_name, default_ds_ptr, conf, @@ -507,12 +507,12 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWPRQ) { // test refines for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES.size(); allowed_ref_idx++) { - auto conf_refine = conf; + auto conf_refine = conf; conf_refine["refine"] = true; conf_refine["refine_k"] = 1.5; conf_refine["refine_type"] = PQ_ALLOWED_REFINES[allowed_ref_idx]; - std::vector params_refine = + std::vector params_refine = {(int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, (int)allowed_ref_idx}; // diff --git a/benchmark/hdf5/benchmark_knowhere.h b/benchmark/hdf5/benchmark_knowhere.h index 6cff6f5e7..fcfd4c250 100644 --- a/benchmark/hdf5/benchmark_knowhere.h +++ b/benchmark/hdf5/benchmark_knowhere.h @@ -121,41 +121,41 @@ class Benchmark_knowhere : public Benchmark_hdf5 { template knowhere::Index create_index( - const std::string& index_type, - const std::string& index_file_name, + const std::string& index_type, + const std::string& index_file_name, const knowhere::DataSetPtr& default_ds_ptr, const knowhere::Json& conf, const std::optional& additional_name = std::nullopt ) { std::string additional_name_s = additional_name.value_or(""); - printf("[%.3f 
s] Creating %sindex \"%s\"\n", - get_time_diff(), - additional_name_s.c_str(), + printf("[%.3f s] Creating %sindex \"%s\"\n", + get_time_diff(), + additional_name_s.c_str(), index_type.c_str()); auto version = knowhere::Version::GetCurrentVersion().VersionNumber(); auto index = knowhere::IndexFactory::Instance().Create(index_type, version); try { - printf("[%.3f s] Reading %sindex file: %s\n", - get_time_diff(), + printf("[%.3f s] Reading %sindex file: %s\n", + get_time_diff(), additional_name_s.c_str(), index_file_name.c_str()); read_index(index.value(), index_file_name, conf); } catch (...) { - printf("[%.3f s] Building %sindex all on %ld vectors\n", - get_time_diff(), - additional_name_s.c_str(), + printf("[%.3f s] Building %sindex all on %ld vectors\n", + get_time_diff(), + additional_name_s.c_str(), default_ds_ptr->GetRows()); auto base = knowhere::ConvertToDataTypeIfNeeded(default_ds_ptr); index.value().Build(base, conf); - printf("[%.3f s] Writing %sindex file: %s\n", - get_time_diff(), - additional_name_s.c_str(), + printf("[%.3f s] Writing %sindex file: %s\n", + get_time_diff(), + additional_name_s.c_str(), index_file_name.c_str()); write_index(index.value(), index_file_name, conf); diff --git a/include/knowhere/index/index_table.h b/include/knowhere/index/index_table.h index 6b5fa4d97..2eed19724 100644 --- a/include/knowhere/index/index_table.h +++ b/include/knowhere/index/index_table.h @@ -130,11 +130,11 @@ static std::set legal_support_mmap_knowhere_index = { // faiss hnsw IndexEnum::INDEX_FAISS_HNSW_FLAT, - IndexEnum::INDEX_FAISS_HNSW_SQ, + IndexEnum::INDEX_FAISS_HNSW_SQ, - IndexEnum::INDEX_FAISS_HNSW_PQ, + IndexEnum::INDEX_FAISS_HNSW_PQ, - IndexEnum::INDEX_FAISS_HNSW_PRQ, + IndexEnum::INDEX_FAISS_HNSW_PRQ, // sparse index IndexEnum::INDEX_SPARSE_INVERTED_INDEX, diff --git a/include/knowhere/utils.h b/include/knowhere/utils.h index f1de7adfd..72475d771 100644 --- a/include/knowhere/utils.h +++ b/include/knowhere/utils.h @@ -111,7 +111,7 @@ template inline DataSetPtr data_type_conversion( const DataSet& src, - const std::optional start = std::nullopt, + const std::optional start = std::nullopt, const std::optional count = std::nullopt ) { auto dim = src.GetDim(); @@ -171,8 +171,8 @@ ConvertToDataTypeIfNeeded(const DataSetPtr& ds, const std::optional sta if (!start.has_value() && !count.has_value()) { return ds; } - } - + } + return data_type_conversion::type, DataType>(*ds, start, count); } diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index eab4be244..a190b007d 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -289,14 +289,14 @@ enum class DataFormatEnum { template struct DataType2EnumHelper {}; -template<> struct DataType2EnumHelper { - static constexpr DataFormatEnum value = DataFormatEnum::fp32; +template<> struct DataType2EnumHelper { + static constexpr DataFormatEnum value = DataFormatEnum::fp32; }; -template<> struct DataType2EnumHelper { - static constexpr DataFormatEnum value = DataFormatEnum::fp16; +template<> struct DataType2EnumHelper { + static constexpr DataFormatEnum value = DataFormatEnum::fp16; }; -template<> struct DataType2EnumHelper { - static constexpr DataFormatEnum value = DataFormatEnum::bf16; +template<> struct DataType2EnumHelper { + static constexpr DataFormatEnum value = DataFormatEnum::bf16; }; template @@ -412,7 +412,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { // const float* cur_query = (const float*)data + idx * dim; const float* cur_query = nullptr; - + 
std::vector cur_query_tmp(dim); if (data_format == DataFormatEnum::fp32) { cur_query = (const float*)data + idx * dim; @@ -696,7 +696,7 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode { } } else { if (data_format == DataFormatEnum::fp32) { - hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); } else if (data_format == DataFormatEnum::fp16) { hnsw_index = std::make_unique( dim, faiss::ScalarQuantizer::QT_fp16, hnsw_cfg.M.value(), metric.value()); @@ -717,8 +717,8 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode { // this function does nothing for the given parameters and indices. // as a result, I'm just keeping it to have is_trained set to true. // WARNING: this may cause problems if ->train() performs some action - // based on the data in the future. Otherwise, data needs to be - // converted into float*. + // based on the data in the future. Otherwise, data needs to be + // converted into float*. hnsw_index->train(rows, (const float*)data); // done @@ -807,7 +807,7 @@ is_flat_refine(const std::optional& refine_type) { if (!refine_type.has_value()) { return true; }; - + // todo: tolower if (refine_type.value() == "FP32" || refine_type.value() == "FLAT") { return true; @@ -820,7 +820,7 @@ is_flat_refine(const std::optional& refine_type) { return expected::Err( Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value())); } - + return false; } @@ -844,7 +844,7 @@ pick_refine_index(const std::optional& refine_type, std::unique_ptr if (is_fp32_flat_v) { // build IndexFlat as a refine auto refine_index = std::make_unique(local_hnsw_index.get()); - + // let refine_index to own everything refine_index->own_fields = true; local_hnsw_index.release(); @@ -864,8 +864,8 @@ pick_refine_index(const std::optional& refine_type, std::unique_ptr // create an sq auto sq_refine = std::make_unique( - local_hnsw_index->storage->d, - refine_sq_type.value(), + local_hnsw_index->storage->d, + refine_sq_type.value(), local_hnsw_index->storage->metric_type ); @@ -923,7 +923,7 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { if (!sq_type.has_value()) { LOG_KNOWHERE_ERROR_ << "Invalid scalar quantizer type: " << hnsw_cfg.sq_type.value(); return Status::invalid_args; - } + } // create an index const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE); @@ -946,7 +946,7 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { auto final_index_cnd = pick_refine_index(hnsw_cfg.refine_type, std::move(hnsw_index)); if (!final_index_cnd.has_value()) { return Status::invalid_args; - } + } // assign final_index = std::move(final_index_cnd.value()); @@ -1043,7 +1043,7 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { auto final_index_cnd = pick_refine_index(hnsw_cfg.refine_type, std::move(hnsw_index)); if (!final_index_cnd.has_value()) { return Status::invalid_args; - } + } // assign final_index = std::move(final_index_cnd.value()); @@ -1097,7 +1097,7 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { // check if we have a refine available. 
faiss::IndexHNSW* index_hnsw = nullptr; - faiss::IndexRefine* const index_refine = + faiss::IndexRefine* const index_refine = dynamic_cast(index.get()); if (index_refine != nullptr) { @@ -1107,7 +1107,7 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { } // recreate hnswpq - std::unique_ptr index_hnsw_pq; + std::unique_ptr index_hnsw_pq; if (index_hnsw->storage->is_cosine) { index_hnsw_pq = std::make_unique(); @@ -1116,7 +1116,7 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { } // C++ slicing - static_cast(*index_hnsw_pq) = + static_cast(*index_hnsw_pq) = std::move(static_cast(*index_hnsw)); // clear out the storage @@ -1204,7 +1204,7 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); // prq - faiss::AdditiveQuantizer::Search_type_t prq_search_type = + faiss::AdditiveQuantizer::Search_type_t prq_search_type = (metric.value() == faiss::MetricType::METRIC_INNER_PRODUCT) ? faiss::AdditiveQuantizer::Search_type_t::ST_LUT_nonorm : faiss::AdditiveQuantizer::Search_type_t::ST_norm_float; @@ -1225,7 +1225,7 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { auto final_index_cnd = pick_refine_index(hnsw_cfg.refine_type, std::move(hnsw_index)); if (!final_index_cnd.has_value()) { return Status::invalid_args; - } + } // assign final_index = std::move(final_index_cnd.value()); @@ -1278,7 +1278,7 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { // check if we have a refine available. faiss::IndexHNSW* index_hnsw = nullptr; - faiss::IndexRefine* const index_refine = + faiss::IndexRefine* const index_refine = dynamic_cast(index.get()); if (index_refine != nullptr) { @@ -1288,7 +1288,7 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { } // recreate hnswprq - std::unique_ptr index_hnsw_prq; + std::unique_ptr index_hnsw_prq; if (index_hnsw->storage->is_cosine) { index_hnsw_prq = std::make_unique(); @@ -1297,7 +1297,7 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { } // C++ slicing - static_cast(*index_hnsw_prq) = + static_cast(*index_hnsw_prq) = std::move(static_cast(*index_hnsw)); // clear out the storage diff --git a/src/index/hnsw/faiss_hnsw_config.h b/src/index/hnsw/faiss_hnsw_config.h index b88dc1ea0..0de816cc5 100644 --- a/src/index/hnsw/faiss_hnsw_config.h +++ b/src/index/hnsw/faiss_hnsw_config.h @@ -119,7 +119,7 @@ class FaissHnswConfig : public BaseConfig { // 'flat' is identical to 'fp32' std::vector allowed_list = { "SQ6", "SQ8", "FP16", "BF16", "FP32", "FLAT"}; - + // todo: tolower() for (const auto& allowed : allowed_list) { @@ -207,7 +207,7 @@ class FaissHnswSqConfig : public FaissHnswConfig { // todo: add more std::vector allowed_list = { "SQ6", "SQ8", "FP16", "BF16"}; - + // todo: tolower() for (const auto& allowed : allowed_list) { From 1ba65004c8d931865356aa8632793aa014017882 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 6 Aug 2024 11:28:54 -0400 Subject: [PATCH 10/21] add fp16 / fb16 support for INDEX_FAISS_SQ Signed-off-by: Alexandr Guzhva --- benchmark/hdf5/benchmark_faiss_hnsw.cpp | 209 +++++++++++++++--------- include/knowhere/index/index_table.h | 2 + src/index/hnsw/faiss_hnsw.cc | 133 +++++++++------ 3 files changed, 213 insertions(+), 131 deletions(-) diff --git a/benchmark/hdf5/benchmark_faiss_hnsw.cpp b/benchmark/hdf5/benchmark_faiss_hnsw.cpp index fd8bf864b..4e05b8883 
100644 --- a/benchmark/hdf5/benchmark_faiss_hnsw.cpp +++ b/benchmark/hdf5/benchmark_faiss_hnsw.cpp @@ -161,13 +161,31 @@ class Benchmark_Faiss_Hnsw : public Benchmark_knowhere, public ::testing::Test { // const std::vector NBITS = {8, 10, 12}; const std::vector NBITS = {8}; - std::unordered_map> SQ_ALLOWED_REFINES = { + // accepted refines for a given SQ type for a FP32 data type + std::unordered_map> SQ_ALLOWED_REFINES_FP32 = { { "SQ6", {"SQ8", "BF16", "FP16", "FLAT"} }, { "SQ8", {"BF16", "FP16", "FLAT"} }, { "BF16", {"FLAT"} }, { "FP16", {"FLAT"} } }; + // accepted refines for a given SQ type for a FP16 data type + std::unordered_map> SQ_ALLOWED_REFINES_FP16 = { + { "SQ6", {"SQ8", "FP16"} }, + { "SQ8", {"FP16"} }, + { "BF16", {} }, + { "FP16", {} } + }; + + // accepted refines for a given SQ type for a BF16 data type + std::unordered_map> SQ_ALLOWED_REFINES_BF16 = { + { "SQ6", {"SQ8", "BF16"} }, + { "SQ8", {"BF16"} }, + { "BF16", {} }, + { "FP16", {} } + }; + + std::vector PQ_ALLOWED_REFINES = { {"SQ6", "SQ8", "BF16", "FP16", "FLAT"} }; @@ -210,50 +228,27 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWFLAT) { auto golden_result = golden_index.Search(query_ds_ptr, conf, nullptr); - // fp32 - printf("\n"); - printf("Processing HNSW,Flat fp32 for %s distance, dim=%d, nrows=%d\n", - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); - // test a candidate + // fp32 candidate + printf("\nProcessing HNSW,Flat fp32 for %s distance, dim=%d, nrows=%d\n", + DISTANCE_TYPES[distance_type].c_str(), dim, nb); + test_hnsw( - default_ds_ptr, - query_ds_ptr, - golden_result.value(), - params, - conf); - - // fp16 - printf("\n"); - printf("Processing HNSW,Flat fp16 for %s distance, dim=%d, nrows=%d\n", - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); - - // test a candidate + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + + // fp16 candidate + printf("\nProcessing HNSW,Flat fp16 for %s distance, dim=%d, nrows=%d\n", + DISTANCE_TYPES[distance_type].c_str(), dim, nb); + test_hnsw( - default_ds_ptr, - query_ds_ptr, - golden_result.value(), - params, - conf); - - // bf32 - printf("\n"); - printf("Processing HNSW,Flat bf16 for %s distance, dim=%d, nrows=%d\n", - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); - - // test a candidate + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + + // bf32 candidate + printf("\nProcessing HNSW,Flat bf16 for %s distance, dim=%d, nrows=%d\n", + DISTANCE_TYPES[distance_type].c_str(), dim, nb); + test_hnsw( - default_ds_ptr, - query_ds_ptr, - golden_result.value(), - params, - conf); + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); } } } @@ -298,49 +293,102 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWSQ) { auto golden_result = golden_index.Search(query_ds_ptr, conf, nullptr); - // - printf("\n"); - printf("Processing HNSW,SQ(%s) fp32 for %s distance, dim=%d, nrows=%d\n", - SQ_TYPES[sq_type].c_str(), - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); - // test a candidate - test_hnsw( - default_ds_ptr, - query_ds_ptr, - golden_result.value(), - params, - conf); - - // test refines - const auto& allowed_refs = SQ_ALLOWED_REFINES[SQ_TYPES[sq_type]]; - for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; - conf_refine["refine_type"] = allowed_refs[allowed_ref_idx]; + // fp32 candidate + printf("\nProcessing HNSW,SQ(%s) fp32 for %s distance, dim=%d, nrows=%d\n", + 
SQ_TYPES[sq_type].c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb); - std::vector params_refine = - {(int)distance_type, dim, nb, (int)sq_type, (int)allowed_ref_idx}; + test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + + // fp16 candidate + printf("\nProcessing HNSW,SQ(%s) fp16 for %s distance, dim=%d, nrows=%d\n", + SQ_TYPES[sq_type].c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb); + + test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + + // bf16 candidate + printf("\nProcessing HNSW,SQ(%s) bf16 for %s distance, dim=%d, nrows=%d\n", + SQ_TYPES[sq_type].c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb); + + test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + + + // test refines for FP32 + { + const auto& allowed_refs = SQ_ALLOWED_REFINES_FP32[SQ_TYPES[sq_type]]; + for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + conf_refine["refine_type"] = allowed_refs[allowed_ref_idx]; + + std::vector params_refine = + {(int)distance_type, dim, nb, (int)sq_type, (int)allowed_ref_idx}; + + // fp32 candidate + printf("\nProcessing HNSW,SQ(%s) with %s refine, fp32 for %s distance, dim=%d, nrows=%d\n", + SQ_TYPES[sq_type].c_str(), + allowed_refs[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); + } + } - // - printf("\n"); - printf("Processing HNSW,SQ(%s) with %s refine, fp32 for %s distance, dim=%d, nrows=%d\n", - SQ_TYPES[sq_type].c_str(), - allowed_refs[allowed_ref_idx].c_str(), - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); + // test refines for FP16 + { + const auto& allowed_refs = SQ_ALLOWED_REFINES_FP16[SQ_TYPES[sq_type]]; + for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + conf_refine["refine_type"] = allowed_refs[allowed_ref_idx]; + + std::vector params_refine = + {(int)distance_type, dim, nb, (int)sq_type, (int)allowed_ref_idx}; + + // fp16 candidate + printf("\nProcessing HNSW,SQ(%s) with %s refine, fp16 for %s distance, dim=%d, nrows=%d\n", + SQ_TYPES[sq_type].c_str(), + allowed_refs[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); + } + } - // test a candidate - test_hnsw( - default_ds_ptr, - query_ds_ptr, - golden_result.value(), - params_refine, - conf_refine); + // test refines for BF16 + { + const auto& allowed_refs = SQ_ALLOWED_REFINES_BF16[SQ_TYPES[sq_type]]; + for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + conf_refine["refine_type"] = allowed_refs[allowed_ref_idx]; + + std::vector params_refine = + {(int)distance_type, dim, nb, (int)sq_type, (int)allowed_ref_idx}; + + // bf16 candidate + printf("\nProcessing HNSW,SQ(%s) with %s refine, bf16 for %s distance, dim=%d, nrows=%d\n", + SQ_TYPES[sq_type].c_str(), + allowed_refs[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); + } 
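The three refine loops above differ only in which refine_type values are accepted for a given input data type; the configuration they build is otherwise identical. A minimal sketch of one such refine-enabled config follows (the dim/rows/metric values are hypothetical; the keys are the ones set in the loops above, and knowhere::Json plus the meta/indexparam constants are assumed to come from the headers this benchmark already includes).

// Shape of a refine-enabled FAISS_HNSW_SQ config, as assembled by the loops above.
knowhere::Json
make_refined_sq_conf() {
    knowhere::Json conf;
    conf[knowhere::meta::METRIC_TYPE] = "L2";       // hypothetical metric
    conf[knowhere::meta::DIM] = 64;                 // hypothetical dimensionality
    conf[knowhere::meta::ROWS] = 10000;             // hypothetical dataset size
    conf[knowhere::meta::INDEX_TYPE] = knowhere::IndexEnum::INDEX_FAISS_HNSW_SQ;
    conf[knowhere::indexparam::SQ_TYPE] = "SQ8";    // base storage quantizer
    conf["refine"] = true;                          // enable the refine (re-ranking) index
    conf["refine_k"] = 1.5;                         // fetch 1.5 * k candidates before re-ranking
    conf["refine_type"] = "FP16";                   // precision of the refine storage
    return conf;
}

Note that the accepted refine types depend on the input data type: per the tables above, fp16 input only accepts SQ8 and FP16 refines, bf16 input only accepts SQ8 and BF16, while fp32 input accepts the full list up to FLAT.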
} } } @@ -348,7 +396,6 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWSQ) { } } - // TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWPQ) { const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_PQ; diff --git a/include/knowhere/index/index_table.h b/include/knowhere/index/index_table.h index 2eed19724..100671e55 100644 --- a/include/knowhere/index/index_table.h +++ b/include/knowhere/index/index_table.h @@ -72,6 +72,8 @@ static std::set> legal_knowhere_index = { {IndexEnum::INDEX_FAISS_HNSW_FLAT, VecType::VECTOR_BFLOAT16}, {IndexEnum::INDEX_FAISS_HNSW_SQ, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_FAISS_HNSW_SQ, VecType::VECTOR_FLOAT16}, + {IndexEnum::INDEX_FAISS_HNSW_SQ, VecType::VECTOR_BFLOAT16}, {IndexEnum::INDEX_FAISS_HNSW_PQ, VecType::VECTOR_FLOAT}, diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index a190b007d..d2954dc28 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -336,6 +336,19 @@ bool convert_rows( } } +// +DataSetPtr convert_ds_to_float(const DataSetPtr& src, DataFormatEnum data_format) { + if (data_format == DataFormatEnum::fp32) { + return src; + } else if (data_format == DataFormatEnum::fp16) { + return ConvertFromDataTypeIfNeeded(src); + } else if (data_format == DataFormatEnum::bf16) { + return ConvertFromDataTypeIfNeeded(src); + } + + return nullptr; +} + } // @@ -616,12 +629,34 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { return Status::empty_index; } - auto data = dataset->GetTensor(); + const auto* data = dataset->GetTensor(); auto rows = dataset->GetRows(); + auto dim = dataset->GetDim(); try { LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; - index->add(rows, reinterpret_cast(data)); + if (data_format == DataFormatEnum::fp32) { + // add as is + index->add(rows, reinterpret_cast(data)); + } else { + // convert data into float in pieces and add to the index + constexpr int64_t n_tmp_rows = 65536; + std::vector tmp(n_tmp_rows * dim); + + for (int64_t irow = 0; irow < rows; irow += n_tmp_rows) { + const int64_t start_row = irow; + const int64_t end_row = std::min(rows, start_row + n_tmp_rows); + const int64_t count_rows = end_row - start_row; + + if (!convert_rows(data, tmp.data(), data_format, start_row, count_rows, dim)) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + + // add + index->add(count_rows, tmp.data()); + } + } } catch (const std::exception& e) { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return Status::faiss_inner_error; @@ -725,49 +760,6 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode { index = std::move(hnsw_index); return Status::success; } - - Status - AddInternal(const DataSetPtr dataset, const Config&) override { - if (index == nullptr) { - LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; - return Status::empty_index; - } - - const auto* data = dataset->GetTensor(); - auto rows = dataset->GetRows(); - auto dim = dataset->GetDim(); - try { - LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; - - if (data_format == DataFormatEnum::fp32) { - // add as is - index->add(rows, reinterpret_cast(data)); - } else { - // convert data into float in pieces and add to the index - constexpr int64_t n_tmp_rows = 65536; - std::vector tmp(n_tmp_rows * dim); - - for (int64_t irow = 0; irow < rows; irow += n_tmp_rows) { - const int64_t start_row = irow; - const int64_t end_row = std::min(rows, start_row + n_tmp_rows); - const int64_t count_rows = end_row 
- start_row; - - if (!convert_rows(data, tmp.data(), data_format, start_row, count_rows, dim)) { - LOG_KNOWHERE_ERROR_ << "Unsupported data format"; - return Status::invalid_args; - } - - // add - index->add(count_rows, tmp.data()); - } - } - } catch (const std::exception& e) { - LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); - return Status::faiss_inner_error; - } - - return Status::success; - } }; template @@ -826,7 +818,11 @@ is_flat_refine(const std::optional& refine_type) { // pick a refine index expected> -pick_refine_index(const std::optional& refine_type, std::unique_ptr&& hnsw_index) { +pick_refine_index( + const DataFormatEnum data_format, + const std::optional& refine_type, + std::unique_ptr&& hnsw_index +) { // yes // grab a type of a refine index @@ -838,6 +834,34 @@ pick_refine_index(const std::optional& refine_type, std::unique_ptr const bool is_fp32_flat_v = is_fp32_flat.value(); + // check input data_format + if (data_format == DataFormatEnum::fp16) { + // make sure that we're using fp16 refine + auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); + if (!(refine_sq_type.has_value() && ( + refine_sq_type.value() != faiss::ScalarQuantizer::QT_bf16 && + !is_fp32_flat_v))) { + + LOG_KNOWHERE_ERROR_ << "fp16 input data does not accept bf16 or fp32 as a refine index."; + return expected>::Err( + Status::invalid_args, "fp16 input data does not accept bf16 or fp32 as a refine index."); + } + } + + if (data_format == DataFormatEnum::bf16) { + // make sure that we're using bf16 refine + auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); + if (!(refine_sq_type.has_value() && ( + refine_sq_type.value() != faiss::ScalarQuantizer::QT_fp16 && + !is_fp32_flat_v))) { + + LOG_KNOWHERE_ERROR_ << "bf16 input data does not accept fp16 or fp32 as a refine index."; + return expected>::Err( + Status::invalid_args, "bf16 input data does not accept fp16 or fp32 as a refine index."); + } + } + + // build std::unique_ptr local_hnsw_index = std::move(hnsw_index); // either build flat or sq @@ -943,7 +967,7 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { std::unique_ptr final_index; if (hnsw_cfg.refine.value_or(false)) { // yes - auto final_index_cnd = pick_refine_index(hnsw_cfg.refine_type, std::move(hnsw_index)); + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); if (!final_index_cnd.has_value()) { return Status::invalid_args; } @@ -960,7 +984,14 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { // train LOG_KNOWHERE_INFO_ << "Training HNSW Index"; - final_index->train(rows, (const float*)data); + //we have to convert the data to float, unfortunately + auto float_ds_ptr = convert_ds_to_float(dataset, data_format); + if (float_ds_ptr == nullptr) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + + final_index->train(rows, reinterpret_cast(float_ds_ptr->GetTensor())); // done index = std::move(final_index); @@ -1040,7 +1071,7 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { std::unique_ptr final_index; if (hnsw_cfg.refine.value_or(false)) { // yes - auto final_index_cnd = pick_refine_index(hnsw_cfg.refine_type, std::move(hnsw_index)); + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); if (!final_index_cnd.has_value()) { return Status::invalid_args; } @@ -1222,7 +1253,7 @@ class BaseFaissRegularIndexHNSWPRQNode : public 
BaseFaissRegularIndexHNSWNode { std::unique_ptr final_index; if (hnsw_cfg.refine.value_or(false)) { // yes - auto final_index_cnd = pick_refine_index(hnsw_cfg.refine_type, std::move(hnsw_index)); + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); if (!final_index_cnd.has_value()) { return Status::invalid_args; } @@ -1339,6 +1370,8 @@ KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNo KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNodeTemplate, bf16); KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, fp32); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, fp16); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, bf16); KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, fp32); From bca03ed906ed530b920273537866a75ddafba609 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 6 Aug 2024 12:01:30 -0400 Subject: [PATCH 11/21] add bf16/pq16 support to FAISS_HNSW_SQ Signed-off-by: Alexandr Guzhva --- benchmark/hdf5/benchmark_faiss_hnsw.cpp | 115 ++++++++++++++++++------ include/knowhere/index/index_table.h | 2 + src/index/hnsw/faiss_hnsw.cc | 92 ++++++++++++------- 3 files changed, 151 insertions(+), 58 deletions(-) diff --git a/benchmark/hdf5/benchmark_faiss_hnsw.cpp b/benchmark/hdf5/benchmark_faiss_hnsw.cpp index 4e05b8883..929541422 100644 --- a/benchmark/hdf5/benchmark_faiss_hnsw.cpp +++ b/benchmark/hdf5/benchmark_faiss_hnsw.cpp @@ -185,10 +185,20 @@ class Benchmark_Faiss_Hnsw : public Benchmark_knowhere, public ::testing::Test { { "FP16", {} } }; - - std::vector PQ_ALLOWED_REFINES = { + // accepted refines for PQ for FP32 data type + std::vector PQ_ALLOWED_REFINES_FP32 = { {"SQ6", "SQ8", "BF16", "FP16", "FLAT"} }; + + // accepted refines for PQ for FP16 data type + std::vector PQ_ALLOWED_REFINES_FP16 = { + {"SQ6", "SQ8", "FP16"} + }; + + // accepted refines for PQ for BF16 data type + std::vector PQ_ALLOWED_REFINES_BF16 = { + {"SQ6", "SQ8", "BF16"} + }; }; // @@ -438,50 +448,106 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWPQ) { auto golden_result = golden_index.Search(query_ds_ptr, conf, nullptr); - // - printf("\n"); - printf("Processing HNSW,PQ%dx%d fp32 for %s distance, dim=%d, nrows=%d\n", + // test fp32 candidate + printf("\nProcessing HNSW,PQ%dx%d fp32 for %s distance, dim=%d, nrows=%d\n", pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb); - // test a candidate test_hnsw( - default_ds_ptr, - query_ds_ptr, - golden_result.value(), - params, - conf); + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); - // test refines - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES.size(); allowed_ref_idx++) { + // test fp16 candidate + printf("\nProcessing HNSW,PQ%dx%d fp16 for %s distance, dim=%d, nrows=%d\n", + pq_m, + NBITS[nbits_type], + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + + // test bf16 candidate + printf("\nProcessing HNSW,PQ%dx%d bf16 for %s distance, dim=%d, nrows=%d\n", + pq_m, + NBITS[nbits_type], + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + + // test refines for fp32 + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); allowed_ref_idx++) 
{ auto conf_refine = conf; conf_refine["refine"] = true; conf_refine["refine_k"] = 1.5; - conf_refine["refine_type"] = PQ_ALLOWED_REFINES[allowed_ref_idx]; + conf_refine["refine_type"] = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; std::vector params_refine = {(int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; - // - printf("\n"); - printf("Processing HNSW,PQ%dx%d with %s refine, fp32 for %s distance, dim=%d, nrows=%d\n", + // test fp32 candidate + printf("\nProcessing HNSW,PQ%dx%d with %s refine, fp32 for %s distance, dim=%d, nrows=%d\n", pq_m, NBITS[nbits_type], - PQ_ALLOWED_REFINES[allowed_ref_idx].c_str(), + PQ_ALLOWED_REFINES_FP32[allowed_ref_idx].c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb); - // test a candidate test_hnsw( - default_ds_ptr, - query_ds_ptr, - golden_result.value(), - params_refine, - conf_refine); + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); + } + + // test refines for fp16 + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + conf_refine["refine_type"] = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; + + std::vector params_refine = + {(int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + + // test fp16 candidate + printf("\nProcessing HNSW,PQ%dx%d with %s refine, fp16 for %s distance, dim=%d, nrows=%d\n", + pq_m, + NBITS[nbits_type], + PQ_ALLOWED_REFINES_FP16[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); + } + + // test refines for bf16 + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + conf_refine["refine_type"] = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; + + std::vector params_refine = + {(int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + + // test bf16 candidate + printf("\nProcessing HNSW,PQ%dx%d with %s refine, bf16 for %s distance, dim=%d, nrows=%d\n", + pq_m, + NBITS[nbits_type], + PQ_ALLOWED_REFINES_BF16[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); } } } @@ -489,7 +555,6 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWPQ) { } } - // TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWPRQ) { const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_PRQ; diff --git a/include/knowhere/index/index_table.h b/include/knowhere/index/index_table.h index 100671e55..b16051e3d 100644 --- a/include/knowhere/index/index_table.h +++ b/include/knowhere/index/index_table.h @@ -76,6 +76,8 @@ static std::set> legal_knowhere_index = { {IndexEnum::INDEX_FAISS_HNSW_SQ, VecType::VECTOR_BFLOAT16}, {IndexEnum::INDEX_FAISS_HNSW_PQ, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_FAISS_HNSW_PQ, VecType::VECTOR_FLOAT16}, + {IndexEnum::INDEX_FAISS_HNSW_PQ, VecType::VECTOR_BFLOAT16}, {IndexEnum::INDEX_FAISS_HNSW_PRQ, VecType::VECTOR_FLOAT}, // diskann diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index d2954dc28..14416cdc5 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -349,6 +349,41 @@ DataSetPtr convert_ds_to_float(const DataSetPtr& src, DataFormatEnum data_format return 
nullptr; } +Status add_to_index( + faiss::Index* const __restrict index, + const DataSetPtr& dataset, + const DataFormatEnum data_format +) { + const auto* data = dataset->GetTensor(); + const auto rows = dataset->GetRows(); + const auto dim = dataset->GetDim(); + + if (data_format == DataFormatEnum::fp32) { + // add as is + index->add(rows, reinterpret_cast(data)); + } else { + // convert data into float in pieces and add to the index + constexpr int64_t n_tmp_rows = 65536; + std::vector tmp(n_tmp_rows * dim); + + for (int64_t irow = 0; irow < rows; irow += n_tmp_rows) { + const int64_t start_row = irow; + const int64_t end_row = std::min(rows, start_row + n_tmp_rows); + const int64_t count_rows = end_row - start_row; + + if (!convert_rows(data, tmp.data(), data_format, start_row, count_rows, dim)) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + + // add + index->add(count_rows, tmp.data()); + } + } + + return Status::success; +} + } // @@ -629,34 +664,15 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { return Status::empty_index; } - const auto* data = dataset->GetTensor(); auto rows = dataset->GetRows(); - auto dim = dataset->GetDim(); try { LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; - if (data_format == DataFormatEnum::fp32) { - // add as is - index->add(rows, reinterpret_cast(data)); - } else { - // convert data into float in pieces and add to the index - constexpr int64_t n_tmp_rows = 65536; - std::vector tmp(n_tmp_rows * dim); - - for (int64_t irow = 0; irow < rows; irow += n_tmp_rows) { - const int64_t start_row = irow; - const int64_t end_row = std::min(rows, start_row + n_tmp_rows); - const int64_t count_rows = end_row - start_row; - - if (!convert_rows(data, tmp.data(), data_format, start_row, count_rows, dim)) { - LOG_KNOWHERE_ERROR_ << "Unsupported data format"; - return Status::invalid_args; - } - - // add - index->add(count_rows, tmp.data()); - } + auto status = add_to_index(index.get(), dataset, data_format); + if (status != Status::success) { + return status; } + } catch (const std::exception& e) { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return Status::faiss_inner_error; @@ -930,8 +946,6 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { auto rows = dataset->GetRows(); // dimensionality of the data auto dim = dataset->GetDim(); - // data - auto data = dataset->GetTensor(); // config auto hnsw_cfg = static_cast(cfg); @@ -984,7 +998,7 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { // train LOG_KNOWHERE_INFO_ << "Training HNSW Index"; - //we have to convert the data to float, unfortunately + // we have to convert the data to float, unfortunately, which costs extra RAM auto float_ds_ptr = convert_ds_to_float(dataset, data_format); if (float_ds_ptr == nullptr) { LOG_KNOWHERE_ERROR_ << "Unsupported data format"; @@ -1031,8 +1045,6 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { auto rows = dataset->GetRows(); // dimensionality of the data auto dim = dataset->GetDim(); - // data - auto data = dataset->GetTensor(); // config auto hnsw_cfg = static_cast(cfg); @@ -1088,12 +1100,19 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { // train hnswflat LOG_KNOWHERE_INFO_ << "Training HNSW Index"; - final_index->train(rows, (const float*)data); + // we have to convert the data to float, unfortunately, which costs extra RAM + auto float_ds_ptr = 
convert_ds_to_float(dataset, data_format); + if (float_ds_ptr == nullptr) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + + final_index->train(rows, reinterpret_cast(float_ds_ptr->GetTensor())); // train pq LOG_KNOWHERE_INFO_ << "Training PQ Index"; - pq_index->train(rows, (const float*)data); + pq_index->train(rows, reinterpret_cast(float_ds_ptr->GetTensor())); pq_index->pq.compute_sdc_table(); // done @@ -1109,18 +1128,23 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { return Status::empty_index; } - auto data = dataset->GetTensor(); auto rows = dataset->GetRows(); try { // hnsw LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; - index->add(rows, reinterpret_cast(data)); + auto status_reg = add_to_index(index.get(), dataset, data_format); + if (status_reg != Status::success) { + return status_reg; + } // pq LOG_KNOWHERE_INFO_ << "Adding " << rows << " to PQ Index"; - tmp_index_pq->add(rows, reinterpret_cast(data)); + auto status_pq = add_to_index(tmp_index_pq.get(), dataset, data_format); + if (status_pq != Status::success) { + return status_pq; + } // we're done. // throw away flat and replace it with pq @@ -1374,6 +1398,8 @@ KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTe KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, bf16); KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, fp32); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, fp16); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, bf16); KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, fp32); From 35c2abb8bf8779c51cb61031b347f1efae9e063f Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 6 Aug 2024 12:14:37 -0400 Subject: [PATCH 12/21] add fp16 and bf16 support for FAISS_HNSW_PRQ Signed-off-by: Alexandr Guzhva --- benchmark/hdf5/benchmark_faiss_hnsw.cpp | 106 +++++++++++++++++++----- include/knowhere/index/index_table.h | 2 + src/index/hnsw/faiss_hnsw.cc | 40 ++++++--- 3 files changed, 117 insertions(+), 31 deletions(-) diff --git a/benchmark/hdf5/benchmark_faiss_hnsw.cpp b/benchmark/hdf5/benchmark_faiss_hnsw.cpp index 929541422..10e815135 100644 --- a/benchmark/hdf5/benchmark_faiss_hnsw.cpp +++ b/benchmark/hdf5/benchmark_faiss_hnsw.cpp @@ -201,6 +201,7 @@ class Benchmark_Faiss_Hnsw : public Benchmark_knowhere, public ::testing::Test { }; }; + // TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWFLAT) { const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_FLAT; @@ -406,6 +407,7 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWSQ) { } } + // TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWPQ) { const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_PQ; @@ -599,9 +601,8 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWPRQ) { auto golden_result = golden_index.Search(query_ds_ptr, conf, nullptr); - // - printf("\n"); - printf("Processing HNSW,PRQ%dx%dx%d fp32 for %s distance, dim=%d, nrows=%d\n", + // test fp32 candidate + printf("\nProcessing HNSW,PRQ%dx%dx%d fp32 for %s distance, dim=%d, nrows=%d\n", prq_num, prq_m, NBITS[nbits_type], @@ -609,45 +610,110 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWPRQ) { dim, nb); - // test a candidate test_hnsw( - default_ds_ptr, - query_ds_ptr, - golden_result.value(), - params, - conf); + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + + // test fp16 candidate + 
printf("\nProcessing HNSW,PRQ%dx%dx%d fp16 for %s distance, dim=%d, nrows=%d\n", + prq_num, + prq_m, + NBITS[nbits_type], + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); - // test refines - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES.size(); allowed_ref_idx++) { + // test bf16 candidate + printf("\nProcessing HNSW,PRQ%dx%dx%d bf16 for %s distance, dim=%d, nrows=%d\n", + prq_num, + prq_m, + NBITS[nbits_type], + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + + // test fp32 refines + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); allowed_ref_idx++) { auto conf_refine = conf; conf_refine["refine"] = true; conf_refine["refine_k"] = 1.5; - conf_refine["refine_type"] = PQ_ALLOWED_REFINES[allowed_ref_idx]; + conf_refine["refine_type"] = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; std::vector params_refine = {(int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, (int)allowed_ref_idx}; // - printf("\n"); - printf("Processing HNSW,PRQ%dx%dx%d with %s refine, fp32 for %s distance, dim=%d, nrows=%d\n", + printf("\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp32 for %s distance, dim=%d, nrows=%d\n", prq_num, prq_m, NBITS[nbits_type], - PQ_ALLOWED_REFINES[allowed_ref_idx].c_str(), + PQ_ALLOWED_REFINES_FP32[allowed_ref_idx].c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb); // test a candidate test_hnsw( - default_ds_ptr, - query_ds_ptr, - golden_result.value(), - params_refine, - conf_refine); + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); + } + + // test fp16 refines + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + conf_refine["refine_type"] = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; + + std::vector params_refine = + {(int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, (int)allowed_ref_idx}; + + // + printf("\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp16 for %s distance, dim=%d, nrows=%d\n", + prq_num, + prq_m, + NBITS[nbits_type], + PQ_ALLOWED_REFINES_FP16[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + // test a candidate + test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); + } + + // test bf16 refines + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + conf_refine["refine_type"] = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; + + std::vector params_refine = + {(int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, (int)allowed_ref_idx}; + + // + printf("\nProcessing HNSW,PRQ%dx%dx%d with %s refine, bf16 for %s distance, dim=%d, nrows=%d\n", + prq_num, + prq_m, + NBITS[nbits_type], + PQ_ALLOWED_REFINES_BF16[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), + dim, + nb); + + // test a candidate + test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); } } } } } } + diff --git a/include/knowhere/index/index_table.h b/include/knowhere/index/index_table.h index b16051e3d..abb9cec3b 100644 --- a/include/knowhere/index/index_table.h +++ 
b/include/knowhere/index/index_table.h @@ -80,6 +80,8 @@ static std::set> legal_knowhere_index = { {IndexEnum::INDEX_FAISS_HNSW_PQ, VecType::VECTOR_BFLOAT16}, {IndexEnum::INDEX_FAISS_HNSW_PRQ, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_FAISS_HNSW_PRQ, VecType::VECTOR_FLOAT16}, + {IndexEnum::INDEX_FAISS_HNSW_PRQ, VecType::VECTOR_BFLOAT16}, // diskann {IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT}, {IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT16}, diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 14416cdc5..6247d3dbb 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -995,9 +995,6 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { final_index = std::move(hnsw_index); } - // train - LOG_KNOWHERE_INFO_ << "Training HNSW Index"; - // we have to convert the data to float, unfortunately, which costs extra RAM auto float_ds_ptr = convert_ds_to_float(dataset, data_format); if (float_ds_ptr == nullptr) { @@ -1005,10 +1002,14 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { return Status::invalid_args; } + // train + LOG_KNOWHERE_INFO_ << "Training HNSW Index"; + final_index->train(rows, reinterpret_cast(float_ds_ptr->GetTensor())); // done index = std::move(final_index); + return Status::success; } }; @@ -1097,9 +1098,6 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { final_index = std::move(hnsw_index); } - // train hnswflat - LOG_KNOWHERE_INFO_ << "Training HNSW Index"; - // we have to convert the data to float, unfortunately, which costs extra RAM auto float_ds_ptr = convert_ds_to_float(dataset, data_format); if (float_ds_ptr == nullptr) { @@ -1107,6 +1105,9 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { return Status::invalid_args; } + // train hnswflat + LOG_KNOWHERE_INFO_ << "Training HNSW Index"; + final_index->train(rows, reinterpret_cast(float_ds_ptr->GetTensor())); // train pq @@ -1118,6 +1119,7 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { // done index = std::move(final_index); tmp_index_pq = std::move(pq_index); + return Status::success; } @@ -1291,19 +1293,27 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { final_index = std::move(hnsw_index); } + // we have to convert the data to float, unfortunately, which costs extra RAM + auto float_ds_ptr = convert_ds_to_float(dataset, data_format); + if (float_ds_ptr == nullptr) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + // train hnswflat LOG_KNOWHERE_INFO_ << "Training HNSW Index"; - final_index->train(rows, (const float*)data); + final_index->train(rows, reinterpret_cast(float_ds_ptr->GetTensor())); // train prq LOG_KNOWHERE_INFO_ << "Training ProductResidualQuantizer Index"; - prq_index->train(rows, (const float*)data); + prq_index->train(rows, reinterpret_cast(float_ds_ptr->GetTensor())); // done index = std::move(final_index); tmp_index_prq = std::move(prq_index); + return Status::success; } @@ -1320,12 +1330,18 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { // hnsw LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; - index->add(rows, reinterpret_cast(data)); + auto status_reg = add_to_index(index.get(), dataset, data_format); + if (status_reg != Status::success) { + return status_reg; + } - // pq + // prq LOG_KNOWHERE_INFO_ << "Adding " << rows << " to ProductResidualQuantizer Index"; - 
tmp_index_prq->add(rows, reinterpret_cast(data)); + auto status_prq = add_to_index(tmp_index_prq.get(), dataset, data_format); + if (status_prq != Status::success) { + return status_prq; + } // we're done. // throw away flat and replace it with prq @@ -1402,5 +1418,7 @@ KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTe KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, bf16); KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, fp32); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, fp16); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, bf16); } // namespace knowhere From 976185ba1fc94f8b1d4b08cea077e18c50382846 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 6 Aug 2024 13:20:22 -0400 Subject: [PATCH 13/21] refactor benchmark for hnsw Signed-off-by: Alexandr Guzhva --- benchmark/benchmark_base.h | 2 +- benchmark/hdf5/benchmark_faiss_hnsw.cpp | 579 ++++++++++-------------- benchmark/hdf5/benchmark_knowhere.h | 60 +-- include/knowhere/comp/index_param.h | 4 +- include/knowhere/utils.h | 14 +- src/index/hnsw/faiss_hnsw.cc | 208 ++++----- src/index/hnsw/faiss_hnsw_config.h | 54 +-- 7 files changed, 373 insertions(+), 548 deletions(-) diff --git a/benchmark/benchmark_base.h b/benchmark/benchmark_base.h index 6acd79227..0c0592db1 100644 --- a/benchmark/benchmark_base.h +++ b/benchmark/benchmark_base.h @@ -16,8 +16,8 @@ #include #include -#include #include +#include #define CALC_TIME_SPAN(X) \ double t_start = elapsed(); \ diff --git a/benchmark/hdf5/benchmark_faiss_hnsw.cpp b/benchmark/hdf5/benchmark_faiss_hnsw.cpp index 10e815135..291ef4b94 100644 --- a/benchmark/hdf5/benchmark_faiss_hnsw.cpp +++ b/benchmark/hdf5/benchmark_faiss_hnsw.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include "benchmark_knowhere.h" #include "knowhere/comp/brute_force.h" @@ -25,7 +26,6 @@ #include "knowhere/comp/knowhere_config.h" #include "knowhere/dataset.h" - knowhere::DataSetPtr GenDataSet(int rows, int dim, const uint64_t seed = 42) { std::mt19937 rng(seed); @@ -39,40 +39,40 @@ GenDataSet(int rows, int dim, const uint64_t seed = 42) { return ds; } +uint64_t +get_params_hash(const std::vector& params) { + std::hash h; + std::hash h64; + + uint64_t result = 0; + + for (const auto value : params) { + result = h64(result ^ h(value) + 17); + } + + return result; +} + // unlike other benchmarks, this one operates on a synthetic data // and verifies the correctness of many-many variants of FAISS HNSW indices. 
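The seeding scheme introduced above ties every synthetic dataset to the parameter tuple that describes it. A short usage sketch (hypothetical parameter values, relying on the GenDataSet and get_params_hash helpers defined above in this file):

// Both the golden flat index and the HNSW candidate are built from a dataset
// whose RNG seed is derived from the same parameter tuple, so re-running a
// single test case always regenerates exactly the same vectors.
void
example_seed_usage() {
    const std::vector<int32_t> golden_params = {/*distance_type*/ 0, /*dim*/ 64, /*nb*/ 10000};
    const uint64_t rng_seed = get_params_hash(golden_params);
    auto default_ds_ptr = GenDataSet(/*rows=*/10000, /*dim=*/64, rng_seed);
    (void)default_ds_ptr;
}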
class Benchmark_Faiss_Hnsw : public Benchmark_knowhere, public ::testing::Test { -public: - template - void test_hnsw( - const knowhere::DataSetPtr& default_ds_ptr, - const knowhere::DataSetPtr& query_ds_ptr, - const knowhere::DataSetPtr& golden_result, - const std::vector& index_params, - const knowhere::Json& conf - ) { + public: + template + void + test_hnsw(const knowhere::DataSetPtr& default_ds_ptr, const knowhere::DataSetPtr& query_ds_ptr, + const knowhere::DataSetPtr& golden_result, const std::vector& index_params, + const knowhere::Json& conf) { const std::string index_type = conf[knowhere::meta::INDEX_TYPE].get(); // load indices - std::string index_file_name = get_index_name( - ann_test_name_, index_type, index_params); + std::string index_file_name = get_index_name(ann_test_name_, index_type, index_params); // our index // first, we create an index and save it - auto index = create_index( - index_type, - index_file_name, - default_ds_ptr, - conf - ); + auto index = create_index(index_type, index_file_name, default_ds_ptr, conf); // then, we force it to be loaded in order to test load & save - auto index_loaded = create_index( - index_type, - index_file_name, - default_ds_ptr, - conf - ); + auto index_loaded = create_index(index_type, index_file_name, default_ds_ptr, conf); auto query_t_ds_ptr = knowhere::ConvertToDataTypeIfNeeded(query_ds_ptr); @@ -80,27 +80,20 @@ class Benchmark_Faiss_Hnsw : public Benchmark_knowhere, public ::testing::Test { auto result_loaded = index_loaded.Search(query_t_ds_ptr, conf, nullptr); // calc recall - auto recall = this->CalcRecall( - golden_result->GetIds(), - result.value()->GetIds(), - query_t_ds_ptr->GetRows(), - conf[knowhere::meta::TOPK].get() - ); - - auto recall_loaded = this->CalcRecall( - golden_result->GetIds(), - result_loaded.value()->GetIds(), - query_t_ds_ptr->GetRows(), - conf[knowhere::meta::TOPK].get() - ); + auto recall = this->CalcRecall(golden_result->GetIds(), result.value()->GetIds(), query_t_ds_ptr->GetRows(), + conf[knowhere::meta::TOPK].get()); + + auto recall_loaded = this->CalcRecall(golden_result->GetIds(), result_loaded.value()->GetIds(), + query_t_ds_ptr->GetRows(), conf[knowhere::meta::TOPK].get()); printf("Recall is %f, %f\n", recall, recall_loaded); - ASSERT_GE(recall, 0.9); - ASSERT_GE(recall_loaded, 0.9); + ASSERT_GE(recall, 0.8); + ASSERT_GE(recall_loaded, 0.8); + ASSERT_FLOAT_EQ(recall, recall_loaded); } -protected: + protected: void SetUp() override { T0_ = elapsed(); @@ -117,10 +110,10 @@ class Benchmark_Faiss_Hnsw : public Benchmark_knowhere, public ::testing::Test { CreateGoldenIndices(); } - void CreateGoldenIndices() { + void + CreateGoldenIndices() { const std::string golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP; - uint64_t rng_seed = 1; for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) { for (const int32_t dim : DIMS) { for (const int32_t nb : NBS) { @@ -128,22 +121,17 @@ class Benchmark_Faiss_Hnsw : public Benchmark_knowhere, public ::testing::Test { conf[knowhere::meta::METRIC_TYPE] = DISTANCE_TYPES[distance_type]; conf[knowhere::meta::DIM] = dim; - std::vector params = {(int)distance_type, dim, nb}; + std::vector golden_params = {(int)distance_type, dim, nb}; + const uint64_t rng_seed = get_params_hash(golden_params); auto default_ds_ptr = GenDataSet(nb, dim, rng_seed); - rng_seed += 1; // create a golden index - std::string golden_index_file_name = get_index_name( - ann_test_name_, golden_index_type, params); - - create_index( - golden_index_type, - 
golden_index_file_name, - default_ds_ptr, - conf, - "golden " - ); + std::string golden_index_file_name = + get_index_name(ann_test_name_, golden_index_type, golden_params); + + create_index(golden_index_type, golden_index_file_name, default_ds_ptr, conf, + "golden "); } } } @@ -163,54 +151,39 @@ class Benchmark_Faiss_Hnsw : public Benchmark_knowhere, public ::testing::Test { // accepted refines for a given SQ type for a FP32 data type std::unordered_map> SQ_ALLOWED_REFINES_FP32 = { - { "SQ6", {"SQ8", "BF16", "FP16", "FLAT"} }, - { "SQ8", {"BF16", "FP16", "FLAT"} }, - { "BF16", {"FLAT"} }, - { "FP16", {"FLAT"} } - }; + {"SQ6", {"SQ8", "BF16", "FP16", "FLAT"}}, + {"SQ8", {"BF16", "FP16", "FLAT"}}, + {"BF16", {"FLAT"}}, + {"FP16", {"FLAT"}}}; // accepted refines for a given SQ type for a FP16 data type std::unordered_map> SQ_ALLOWED_REFINES_FP16 = { - { "SQ6", {"SQ8", "FP16"} }, - { "SQ8", {"FP16"} }, - { "BF16", {} }, - { "FP16", {} } - }; + {"SQ6", {"SQ8", "FP16"}}, {"SQ8", {"FP16"}}, {"BF16", {}}, {"FP16", {}}}; // accepted refines for a given SQ type for a BF16 data type std::unordered_map> SQ_ALLOWED_REFINES_BF16 = { - { "SQ6", {"SQ8", "BF16"} }, - { "SQ8", {"BF16"} }, - { "BF16", {} }, - { "FP16", {} } - }; + {"SQ6", {"SQ8", "BF16"}}, {"SQ8", {"BF16"}}, {"BF16", {}}, {"FP16", {}}}; // accepted refines for PQ for FP32 data type - std::vector PQ_ALLOWED_REFINES_FP32 = { - {"SQ6", "SQ8", "BF16", "FP16", "FLAT"} - }; + std::vector PQ_ALLOWED_REFINES_FP32 = {{"SQ6", "SQ8", "BF16", "FP16", "FLAT"}}; // accepted refines for PQ for FP16 data type - std::vector PQ_ALLOWED_REFINES_FP16 = { - {"SQ6", "SQ8", "FP16"} - }; + std::vector PQ_ALLOWED_REFINES_FP16 = {{"SQ6", "SQ8", "FP16"}}; // accepted refines for PQ for BF16 data type - std::vector PQ_ALLOWED_REFINES_BF16 = { - {"SQ6", "SQ8", "BF16"} - }; + std::vector PQ_ALLOWED_REFINES_BF16 = {{"SQ6", "SQ8", "BF16"}}; }; - // TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWFLAT) { const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_FLAT; const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP; - uint64_t rng_seed = 1; for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) { for (const int32_t dim : DIMS) { - auto query_ds_ptr = GenDataSet(NQ, dim, rng_seed + 1234567); + // generate a query + const uint64_t query_rng_seed = get_params_hash({(int)distance_type, dim}); + auto query_ds_ptr = GenDataSet(NQ, dim, query_rng_seed); for (const int32_t nb : NBS) { knowhere::Json conf = cfg_; @@ -219,47 +192,39 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWFLAT) { conf[knowhere::meta::ROWS] = nb; conf[knowhere::meta::INDEX_TYPE] = index_type; + std::vector golden_params = {(int)distance_type, dim, nb}; std::vector params = {(int)distance_type, dim, nb}; // generate a default dataset + const uint64_t rng_seed = get_params_hash(golden_params); auto default_ds_ptr = GenDataSet(nb, dim, rng_seed); - rng_seed += 1; // get a golden result - std::string golden_index_file_name = get_index_name( - ann_test_name_, golden_index_type, params); + std::string golden_index_file_name = + get_index_name(ann_test_name_, golden_index_type, golden_params); - auto golden_index = create_index( - golden_index_type, - golden_index_file_name, - default_ds_ptr, - conf, - "golden " - ); + auto golden_index = create_index(golden_index_type, golden_index_file_name, + default_ds_ptr, conf, "golden "); auto golden_result = golden_index.Search(query_ds_ptr, conf, nullptr); - // fp32 candidate printf("\nProcessing HNSW,Flat fp32 for %s 
distance, dim=%d, nrows=%d\n", - DISTANCE_TYPES[distance_type].c_str(), dim, nb); + DISTANCE_TYPES[distance_type].c_str(), dim, nb); - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); // fp16 candidate printf("\nProcessing HNSW,Flat fp16 for %s distance, dim=%d, nrows=%d\n", - DISTANCE_TYPES[distance_type].c_str(), dim, nb); + DISTANCE_TYPES[distance_type].c_str(), dim, nb); - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); // bf32 candidate printf("\nProcessing HNSW,Flat bf16 for %s distance, dim=%d, nrows=%d\n", - DISTANCE_TYPES[distance_type].c_str(), dim, nb); + DISTANCE_TYPES[distance_type].c_str(), dim, nb); - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); } } } @@ -270,62 +235,59 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWSQ) { const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_SQ; const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP; - uint64_t rng_seed = 1; for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) { for (const int32_t dim : DIMS) { - auto query_ds_ptr = GenDataSet(NQ, dim, rng_seed + 1234567); + // generate a query + const uint64_t query_rng_seed = get_params_hash({(int)distance_type, dim}); + auto query_ds_ptr = GenDataSet(NQ, dim, query_rng_seed); for (const int32_t nb : NBS) { - for (size_t sq_type = 0; sq_type < SQ_TYPES.size(); sq_type++) { - knowhere::Json conf = cfg_; - conf[knowhere::meta::METRIC_TYPE] = DISTANCE_TYPES[distance_type]; - conf[knowhere::meta::DIM] = dim; - conf[knowhere::meta::ROWS] = nb; - conf[knowhere::meta::INDEX_TYPE] = index_type; - conf[knowhere::indexparam::SQ_TYPE] = SQ_TYPES[sq_type]; + // create golden conf + knowhere::Json conf_golden = cfg_; + conf_golden[knowhere::meta::METRIC_TYPE] = DISTANCE_TYPES[distance_type]; + conf_golden[knowhere::meta::DIM] = dim; + conf_golden[knowhere::meta::ROWS] = nb; - std::vector params = {(int)distance_type, dim, nb, (int)sq_type}; + std::vector golden_params = {(int)distance_type, dim, nb}; - // generate a default dataset - auto default_ds_ptr = GenDataSet(nb, dim, rng_seed); - rng_seed += 1; + // generate a default dataset + const uint64_t rng_seed = get_params_hash(golden_params); + auto default_ds_ptr = GenDataSet(nb, dim, rng_seed); - // get a golden result - std::string golden_index_file_name = get_index_name( - ann_test_name_, golden_index_type, params); + // get a golden result + std::string golden_index_file_name = + get_index_name(ann_test_name_, golden_index_type, golden_params); + + auto golden_index = create_index(golden_index_type, golden_index_file_name, + default_ds_ptr, conf_golden, "golden "); - auto golden_index = create_index( - golden_index_type, - golden_index_file_name, - default_ds_ptr, - conf, - "golden " - ); + auto golden_result = golden_index.Search(query_ds_ptr, conf_golden, nullptr); - auto golden_result = golden_index.Search(query_ds_ptr, conf, nullptr); + // go SQ + for (size_t sq_type = 0; sq_type < SQ_TYPES.size(); sq_type++) { + knowhere::Json conf = conf_golden; + conf[knowhere::meta::INDEX_TYPE] = index_type; + conf[knowhere::indexparam::SQ_TYPE] = SQ_TYPES[sq_type]; + std::vector params = {(int)distance_type, dim, nb, (int)sq_type}; // fp32 candidate 
printf("\nProcessing HNSW,SQ(%s) fp32 for %s distance, dim=%d, nrows=%d\n", - SQ_TYPES[sq_type].c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb); + SQ_TYPES[sq_type].c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb); - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); // fp16 candidate printf("\nProcessing HNSW,SQ(%s) fp16 for %s distance, dim=%d, nrows=%d\n", - SQ_TYPES[sq_type].c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb); + SQ_TYPES[sq_type].c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb); - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); // bf16 candidate printf("\nProcessing HNSW,SQ(%s) bf16 for %s distance, dim=%d, nrows=%d\n", - SQ_TYPES[sq_type].c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb); - - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + SQ_TYPES[sq_type].c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb); + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); // test refines for FP32 { @@ -336,19 +298,16 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWSQ) { conf_refine["refine_k"] = 1.5; conf_refine["refine_type"] = allowed_refs[allowed_ref_idx]; - std::vector params_refine = - {(int)distance_type, dim, nb, (int)sq_type, (int)allowed_ref_idx}; + std::vector params_refine = {(int)distance_type, dim, nb, (int)sq_type, + (int)allowed_ref_idx}; // fp32 candidate printf("\nProcessing HNSW,SQ(%s) with %s refine, fp32 for %s distance, dim=%d, nrows=%d\n", - SQ_TYPES[sq_type].c_str(), - allowed_refs[allowed_ref_idx].c_str(), - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); - - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); + SQ_TYPES[sq_type].c_str(), allowed_refs[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb); + + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params_refine, conf_refine); } } @@ -361,19 +320,16 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWSQ) { conf_refine["refine_k"] = 1.5; conf_refine["refine_type"] = allowed_refs[allowed_ref_idx]; - std::vector params_refine = - {(int)distance_type, dim, nb, (int)sq_type, (int)allowed_ref_idx}; + std::vector params_refine = {(int)distance_type, dim, nb, (int)sq_type, + (int)allowed_ref_idx}; // fp16 candidate printf("\nProcessing HNSW,SQ(%s) with %s refine, fp16 for %s distance, dim=%d, nrows=%d\n", - SQ_TYPES[sq_type].c_str(), - allowed_refs[allowed_ref_idx].c_str(), - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); - - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); + SQ_TYPES[sq_type].c_str(), allowed_refs[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb); + + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params_refine, conf_refine); } } @@ -386,19 +342,16 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWSQ) { conf_refine["refine_k"] = 1.5; conf_refine["refine_type"] = allowed_refs[allowed_ref_idx]; - std::vector params_refine = - {(int)distance_type, dim, nb, (int)sq_type, (int)allowed_ref_idx}; + std::vector params_refine = {(int)distance_type, dim, nb, (int)sq_type, + (int)allowed_ref_idx}; // bf16 candidate printf("\nProcessing HNSW,SQ(%s) with %s refine, bf16 for %s distance, dim=%d, nrows=%d\n", - 
SQ_TYPES[sq_type].c_str(), - allowed_refs[allowed_ref_idx].c_str(), - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); - - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); + SQ_TYPES[sq_type].c_str(), allowed_refs[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb); + + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params_refine, conf_refine); } } } @@ -407,149 +360,126 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWSQ) { } } - // TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWPQ) { const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_PQ; const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP; - uint64_t rng_seed = 1; for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) { for (const int32_t dim : {16}) { - auto query_ds_ptr = GenDataSet(NQ, dim, rng_seed + 1234567); + // generate a query + const uint64_t query_rng_seed = get_params_hash({(int)distance_type, dim}); + auto query_ds_ptr = GenDataSet(NQ, dim, query_rng_seed); for (const int32_t nb : NBS) { + // set up a golden cfg + knowhere::Json conf_golden = cfg_; + conf_golden[knowhere::meta::METRIC_TYPE] = DISTANCE_TYPES[distance_type]; + conf_golden[knowhere::meta::DIM] = dim; + conf_golden[knowhere::meta::ROWS] = nb; + + std::vector golden_params = {(int)distance_type, dim, nb}; + + // generate a default dataset + const uint64_t rng_seed = get_params_hash(golden_params); + auto default_ds_ptr = GenDataSet(nb, dim, rng_seed); + + // get a golden result + std::string golden_index_file_name = + get_index_name(ann_test_name_, golden_index_type, golden_params); + + auto golden_index = create_index(golden_index_type, golden_index_file_name, + default_ds_ptr, conf_golden, "golden "); + + auto golden_result = golden_index.Search(query_ds_ptr, conf_golden, nullptr); + + // go PQ for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { - const int pq_m = 4; + const int pq_m = 8; - knowhere::Json conf = cfg_; - conf[knowhere::meta::METRIC_TYPE] = DISTANCE_TYPES[distance_type]; - conf[knowhere::meta::DIM] = dim; - conf[knowhere::meta::ROWS] = nb; + knowhere::Json conf = conf_golden; conf[knowhere::meta::INDEX_TYPE] = index_type; conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; conf[knowhere::indexparam::M] = pq_m; std::vector params = {(int)distance_type, dim, nb, pq_m, (int)nbits_type}; - // generate a default dataset - auto default_ds_ptr = GenDataSet(nb, dim, rng_seed); - rng_seed += 1; - - // get a golden result - std::string golden_index_file_name = get_index_name( - ann_test_name_, golden_index_type, params); - - auto golden_index = create_index( - golden_index_type, - golden_index_file_name, - default_ds_ptr, - conf, - "golden " - ); - - auto golden_result = golden_index.Search(query_ds_ptr, conf, nullptr); - // test fp32 candidate - printf("\nProcessing HNSW,PQ%dx%d fp32 for %s distance, dim=%d, nrows=%d\n", - pq_m, - NBITS[nbits_type], - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); + printf("\nProcessing HNSW,PQ%dx%d fp32 for %s distance, dim=%d, nrows=%d\n", pq_m, + NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb); - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); // test fp16 candidate - printf("\nProcessing HNSW,PQ%dx%d fp16 for %s distance, dim=%d, nrows=%d\n", - pq_m, - NBITS[nbits_type], - DISTANCE_TYPES[distance_type].c_str(), 
- dim, - nb); + printf("\nProcessing HNSW,PQ%dx%d fp16 for %s distance, dim=%d, nrows=%d\n", pq_m, + NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb); - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); // test bf16 candidate - printf("\nProcessing HNSW,PQ%dx%d bf16 for %s distance, dim=%d, nrows=%d\n", - pq_m, - NBITS[nbits_type], - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); + printf("\nProcessing HNSW,PQ%dx%d bf16 for %s distance, dim=%d, nrows=%d\n", pq_m, + NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb); - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); // test refines for fp32 - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); allowed_ref_idx++) { + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); + allowed_ref_idx++) { auto conf_refine = conf; conf_refine["refine"] = true; conf_refine["refine_k"] = 1.5; conf_refine["refine_type"] = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; - std::vector params_refine = - {(int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + std::vector params_refine = {(int)distance_type, dim, nb, pq_m, (int)nbits_type, + (int)allowed_ref_idx}; // test fp32 candidate printf("\nProcessing HNSW,PQ%dx%d with %s refine, fp32 for %s distance, dim=%d, nrows=%d\n", - pq_m, - NBITS[nbits_type], - PQ_ALLOWED_REFINES_FP32[allowed_ref_idx].c_str(), - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); - - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); + pq_m, NBITS[nbits_type], PQ_ALLOWED_REFINES_FP32[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb); + + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine); } // test refines for fp16 - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); allowed_ref_idx++) { + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); + allowed_ref_idx++) { auto conf_refine = conf; conf_refine["refine"] = true; conf_refine["refine_k"] = 1.5; conf_refine["refine_type"] = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; - std::vector params_refine = - {(int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + std::vector params_refine = {(int)distance_type, dim, nb, pq_m, (int)nbits_type, + (int)allowed_ref_idx}; // test fp16 candidate printf("\nProcessing HNSW,PQ%dx%d with %s refine, fp16 for %s distance, dim=%d, nrows=%d\n", - pq_m, - NBITS[nbits_type], - PQ_ALLOWED_REFINES_FP16[allowed_ref_idx].c_str(), - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); - - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); + pq_m, NBITS[nbits_type], PQ_ALLOWED_REFINES_FP16[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb); + + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine); } // test refines for bf16 - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); allowed_ref_idx++) { + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); + allowed_ref_idx++) { auto conf_refine = conf; conf_refine["refine"] = true; conf_refine["refine_k"] = 1.5; 
conf_refine["refine_type"] = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; - std::vector params_refine = - {(int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + std::vector params_refine = {(int)distance_type, dim, nb, pq_m, (int)nbits_type, + (int)allowed_ref_idx}; // test bf16 candidate printf("\nProcessing HNSW,PQ%dx%d with %s refine, bf16 for %s distance, dim=%d, nrows=%d\n", - pq_m, - NBITS[nbits_type], - PQ_ALLOWED_REFINES_BF16[allowed_ref_idx].c_str(), - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); - - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); + pq_m, NBITS[nbits_type], PQ_ALLOWED_REFINES_BF16[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb); + + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine); } } } @@ -562,20 +492,40 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWPRQ) { const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_PRQ; const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP; - uint64_t rng_seed = 1; for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) { for (const int32_t dim : {16}) { - auto query_ds_ptr = GenDataSet(NQ, dim, rng_seed + 1234567); + // generate a query + const uint64_t query_rng_seed = get_params_hash({(int)distance_type, dim}); + auto query_ds_ptr = GenDataSet(NQ, dim, query_rng_seed); for (const int32_t nb : NBS) { + // set up a golden cfg + knowhere::Json conf_golden = cfg_; + conf_golden[knowhere::meta::METRIC_TYPE] = DISTANCE_TYPES[distance_type]; + conf_golden[knowhere::meta::DIM] = dim; + conf_golden[knowhere::meta::ROWS] = nb; + + std::vector golden_params = {(int)distance_type, dim, nb}; + + // generate a default dataset + const uint64_t rng_seed = get_params_hash(golden_params); + auto default_ds_ptr = GenDataSet(nb, dim, rng_seed); + + // get a golden result + std::string golden_index_file_name = + get_index_name(ann_test_name_, golden_index_type, golden_params); + + auto golden_index = create_index(golden_index_type, golden_index_file_name, + default_ds_ptr, conf_golden, "golden "); + + auto golden_result = golden_index.Search(query_ds_ptr, conf_golden, nullptr); + + // go PRQ for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { const int prq_m = 4; const int prq_num = 2; - knowhere::Json conf = cfg_; - conf[knowhere::meta::METRIC_TYPE] = DISTANCE_TYPES[distance_type]; - conf[knowhere::meta::DIM] = dim; - conf[knowhere::meta::ROWS] = nb; + knowhere::Json conf = conf_golden; conf[knowhere::meta::INDEX_TYPE] = index_type; conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; conf[knowhere::indexparam::M] = prq_m; @@ -583,137 +533,88 @@ TEST_F(Benchmark_Faiss_Hnsw, TEST_HNSWPRQ) { std::vector params = {(int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type}; - // generate a default dataset - auto default_ds_ptr = GenDataSet(nb, dim, rng_seed); - rng_seed += 1; - - // get a golden result - std::string golden_index_file_name = get_index_name( - ann_test_name_, golden_index_type, params); - - auto golden_index = create_index( - golden_index_type, - golden_index_file_name, - default_ds_ptr, - conf, - "golden " - ); - - auto golden_result = golden_index.Search(query_ds_ptr, conf, nullptr); - // test fp32 candidate - printf("\nProcessing HNSW,PRQ%dx%dx%d fp32 for %s distance, dim=%d, nrows=%d\n", - prq_num, - prq_m, - NBITS[nbits_type], - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); + printf("\nProcessing 
HNSW,PRQ%dx%dx%d fp32 for %s distance, dim=%d, nrows=%d\n", prq_num, prq_m, + NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb); - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); // test fp16 candidate - printf("\nProcessing HNSW,PRQ%dx%dx%d fp16 for %s distance, dim=%d, nrows=%d\n", - prq_num, - prq_m, - NBITS[nbits_type], - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); + printf("\nProcessing HNSW,PRQ%dx%dx%d fp16 for %s distance, dim=%d, nrows=%d\n", prq_num, prq_m, + NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb); - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); // test bf16 candidate - printf("\nProcessing HNSW,PRQ%dx%dx%d bf16 for %s distance, dim=%d, nrows=%d\n", - prq_num, - prq_m, - NBITS[nbits_type], - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); + printf("\nProcessing HNSW,PRQ%dx%dx%d bf16 for %s distance, dim=%d, nrows=%d\n", prq_num, prq_m, + NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb); - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf); // test fp32 refines - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); allowed_ref_idx++) { + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); + allowed_ref_idx++) { auto conf_refine = conf; conf_refine["refine"] = true; conf_refine["refine_k"] = 1.5; conf_refine["refine_type"] = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; - std::vector params_refine = - {(int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, (int)allowed_ref_idx}; + std::vector params_refine = { + (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, (int)allowed_ref_idx}; // printf("\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp32 for %s distance, dim=%d, nrows=%d\n", - prq_num, - prq_m, - NBITS[nbits_type], - PQ_ALLOWED_REFINES_FP32[allowed_ref_idx].c_str(), - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); + prq_num, prq_m, NBITS[nbits_type], PQ_ALLOWED_REFINES_FP32[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb); // test a candidate - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine); } // test fp16 refines - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); allowed_ref_idx++) { + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); + allowed_ref_idx++) { auto conf_refine = conf; conf_refine["refine"] = true; conf_refine["refine_k"] = 1.5; conf_refine["refine_type"] = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; - std::vector params_refine = - {(int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, (int)allowed_ref_idx}; + std::vector params_refine = { + (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, (int)allowed_ref_idx}; // printf("\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp16 for %s distance, dim=%d, nrows=%d\n", - prq_num, - prq_m, - NBITS[nbits_type], - PQ_ALLOWED_REFINES_FP16[allowed_ref_idx].c_str(), - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); + prq_num, prq_m, NBITS[nbits_type], 
PQ_ALLOWED_REFINES_FP16[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb); // test a candidate - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine); } // test bf16 refines - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); allowed_ref_idx++) { + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); + allowed_ref_idx++) { auto conf_refine = conf; conf_refine["refine"] = true; conf_refine["refine_k"] = 1.5; conf_refine["refine_type"] = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; - std::vector params_refine = - {(int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, (int)allowed_ref_idx}; + std::vector params_refine = { + (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, (int)allowed_ref_idx}; // printf("\nProcessing HNSW,PRQ%dx%dx%d with %s refine, bf16 for %s distance, dim=%d, nrows=%d\n", - prq_num, - prq_m, - NBITS[nbits_type], - PQ_ALLOWED_REFINES_BF16[allowed_ref_idx].c_str(), - DISTANCE_TYPES[distance_type].c_str(), - dim, - nb); + prq_num, prq_m, NBITS[nbits_type], PQ_ALLOWED_REFINES_BF16[allowed_ref_idx].c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb); // test a candidate - test_hnsw( - default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, conf_refine); + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine); } } } } } } - diff --git a/benchmark/hdf5/benchmark_knowhere.h b/benchmark/hdf5/benchmark_knowhere.h index fcfd4c250..1ebaef45a 100644 --- a/benchmark/hdf5/benchmark_knowhere.h +++ b/benchmark/hdf5/benchmark_knowhere.h @@ -90,13 +90,10 @@ class Benchmark_knowhere : public Benchmark_hdf5 { index.Deserialize(binary_set, conf); } - template + template static std::string - get_index_name( - const std::string& ann_test_name, - const std::string& index_type, - const std::vector& params - ) { + get_index_name(const std::string& ann_test_name, const std::string& index_type, + const std::vector& params) { std::string params_str = ""; for (size_t i = 0; i < params.size(); i++) { params_str += "_" + std::to_string(params[i]); @@ -118,45 +115,32 @@ class Benchmark_knowhere : public Benchmark_hdf5 { return this->get_index_name(ann_test_name_, index_type_, params); } - template + template knowhere::Index - create_index( - const std::string& index_type, - const std::string& index_file_name, - const knowhere::DataSetPtr& default_ds_ptr, - const knowhere::Json& conf, - const std::optional& additional_name = std::nullopt - ) { + create_index(const std::string& index_type, const std::string& index_file_name, + const knowhere::DataSetPtr& default_ds_ptr, const knowhere::Json& conf, + const std::optional& additional_name = std::nullopt) { std::string additional_name_s = additional_name.value_or(""); - printf("[%.3f s] Creating %sindex \"%s\"\n", - get_time_diff(), - additional_name_s.c_str(), - index_type.c_str()); + printf("[%.3f s] Creating %sindex \"%s\"\n", get_time_diff(), additional_name_s.c_str(), index_type.c_str()); auto version = knowhere::Version::GetCurrentVersion().VersionNumber(); auto index = knowhere::IndexFactory::Instance().Create(index_type, version); try { - printf("[%.3f s] Reading %sindex file: %s\n", - get_time_diff(), - additional_name_s.c_str(), - index_file_name.c_str()); + printf("[%.3f s] Reading %sindex file: %s\n", get_time_diff(), additional_name_s.c_str(), + 
index_file_name.c_str()); read_index(index.value(), index_file_name, conf); } catch (...) { - printf("[%.3f s] Building %sindex all on %ld vectors\n", - get_time_diff(), - additional_name_s.c_str(), - default_ds_ptr->GetRows()); + printf("[%.3f s] Building %sindex all on %ld vectors\n", get_time_diff(), additional_name_s.c_str(), + default_ds_ptr->GetRows()); auto base = knowhere::ConvertToDataTypeIfNeeded(default_ds_ptr); index.value().Build(base, conf); - printf("[%.3f s] Writing %sindex file: %s\n", - get_time_diff(), - additional_name_s.c_str(), - index_file_name.c_str()); + printf("[%.3f s] Writing %sindex file: %s\n", get_time_diff(), additional_name_s.c_str(), + index_file_name.c_str()); write_index(index.value(), index_file_name, conf); } @@ -167,12 +151,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 { template knowhere::Index create_index(const std::string& index_file_name, const knowhere::Json& conf) { - auto idx = this->create_index( - index_type_, - index_file_name, - knowhere::GenDataSet(nb_, dim_, xb_), - conf - ); + auto idx = this->create_index(index_type_, index_file_name, knowhere::GenDataSet(nb_, dim_, xb_), conf); index_ = idx; return idx; } @@ -182,13 +161,8 @@ class Benchmark_knowhere : public Benchmark_hdf5 { golden_index_type_ = knowhere::IndexEnum::INDEX_FAISS_IDMAP; std::string golden_index_file_name = ann_test_name_ + "_" + golden_index_type_ + "_GOLDEN" + ".index"; - auto idx = this->create_index( - golden_index_type_, - golden_index_file_name, - knowhere::GenDataSet(nb_, dim_, xb_), - conf, - "golden " - ); + auto idx = this->create_index(golden_index_type_, golden_index_file_name, + knowhere::GenDataSet(nb_, dim_, xb_), conf, "golden "); golden_index_ = idx; return idx; } diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index d9b98e6df..697119789 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -156,8 +156,8 @@ constexpr const char* EF = "ef"; constexpr const char* OVERVIEW_LEVELS = "overview_levels"; // FAISS additional Params -constexpr const char* SQ_TYPE = "sq_type"; // for IVF_SQ and HNSW_SQ -constexpr const char* PRQ_NUM = "nrq"; // for PRQ, number of redisual quantizers +constexpr const char* SQ_TYPE = "sq_type"; // for IVF_SQ and HNSW_SQ +constexpr const char* PRQ_NUM = "nrq"; // for PRQ, number of redisual quantizers // Sparse Params constexpr const char* DROP_RATIO_BUILD = "drop_ratio_build"; diff --git a/include/knowhere/utils.h b/include/knowhere/utils.h index 72475d771..866942279 100644 --- a/include/knowhere/utils.h +++ b/include/knowhere/utils.h @@ -109,11 +109,8 @@ GetKey(const std::string& name) { template inline DataSetPtr -data_type_conversion( - const DataSet& src, - const std::optional start = std::nullopt, - const std::optional count = std::nullopt -) { +data_type_conversion(const DataSet& src, const std::optional start = std::nullopt, + const std::optional count = std::nullopt) { auto dim = src.GetDim(); auto rows = src.GetRows(); @@ -152,7 +149,8 @@ data_type_conversion( // * invalid start, count values -> returns nullptr template inline DataSetPtr -ConvertFromDataTypeIfNeeded(const DataSetPtr& ds, const std::optional start = std::nullopt, const std::optional count = std::nullopt) { +ConvertFromDataTypeIfNeeded(const DataSetPtr& ds, const std::optional start = std::nullopt, + const std::optional count = std::nullopt) { if constexpr (std::is_same_v::type>) { if (!start.has_value() && !count.has_value()) { return ds; @@ -162,11 +160,11 @@ 
ConvertFromDataTypeIfNeeded(const DataSetPtr& ds, const std::optional s return data_type_conversion::type>(*ds, start, count); } - // Convert DataSet from float to DataType template inline DataSetPtr -ConvertToDataTypeIfNeeded(const DataSetPtr& ds, const std::optional start = std::nullopt, const std::optional count = std::nullopt) { +ConvertToDataTypeIfNeeded(const DataSetPtr& ds, const std::optional start = std::nullopt, + const std::optional count = std::nullopt) { if constexpr (std::is_same_v::type>) { if (!start.has_value() && !count.has_value()) { return ds; diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 6247d3dbb..a3f3c77c0 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -280,42 +280,34 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode { }; // -enum class DataFormatEnum { - fp32, - fp16, - bf16 -}; +enum class DataFormatEnum { fp32, fp16, bf16 }; -template +template struct DataType2EnumHelper {}; -template<> struct DataType2EnumHelper { +template <> +struct DataType2EnumHelper { static constexpr DataFormatEnum value = DataFormatEnum::fp32; }; -template<> struct DataType2EnumHelper { +template <> +struct DataType2EnumHelper { static constexpr DataFormatEnum value = DataFormatEnum::fp16; }; -template<> struct DataType2EnumHelper { +template <> +struct DataType2EnumHelper { static constexpr DataFormatEnum value = DataFormatEnum::bf16; }; -template +template static constexpr DataFormatEnum datatype_v = DataType2EnumHelper::value; - - // namespace { // -bool convert_rows( - const void* const __restrict src_in, - float* const __restrict dst, - const DataFormatEnum data_format, - const size_t start_row, - const size_t nrows, - const size_t dim -) { +bool +convert_rows(const void* const __restrict src_in, float* const __restrict dst, const DataFormatEnum data_format, + const size_t start_row, const size_t nrows, const size_t dim) { if (data_format == DataFormatEnum::fp16) { const knowhere::fp16* const src = reinterpret_cast(src_in); for (size_t i = 0; i < nrows * dim; i++) { @@ -337,7 +329,8 @@ bool convert_rows( } // -DataSetPtr convert_ds_to_float(const DataSetPtr& src, DataFormatEnum data_format) { +DataSetPtr +convert_ds_to_float(const DataSetPtr& src, DataFormatEnum data_format) { if (data_format == DataFormatEnum::fp32) { return src; } else if (data_format == DataFormatEnum::fp16) { @@ -349,11 +342,8 @@ DataSetPtr convert_ds_to_float(const DataSetPtr& src, DataFormatEnum data_format return nullptr; } -Status add_to_index( - faiss::Index* const __restrict index, - const DataSetPtr& dataset, - const DataFormatEnum data_format -) { +Status +add_to_index(faiss::Index* const __restrict index, const DataSetPtr& dataset, const DataFormatEnum data_format) { const auto* data = dataset->GetTensor(); const auto rows = dataset->GetRows(); const auto dim = dataset->GetDim(); @@ -384,7 +374,7 @@ Status add_to_index( return Status::success; } -} +} // namespace // class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { @@ -469,7 +459,6 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { cur_query = cur_query_tmp.data(); } - // set up local results faiss::idx_t* const __restrict local_ids = ids.get() + k * idx; float* const __restrict local_distances = distances.get() + k * idx; @@ -541,7 +530,8 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { // yes, wrap both base and refine index knowhere::IndexWrapperCosine cosine_wrapper( index_refine->refine_index, - 
dynamic_cast(index_hnsw->storage)->get_inverse_l2_norms()); + dynamic_cast(index_hnsw->storage) + ->get_inverse_l2_norms()); // create a temporary refine index which does not own faiss::IndexRefine tmp_refine(base_wrapper, &cosine_wrapper); @@ -632,7 +622,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { bool WhetherPerformBruteForceSearch(const BaseConfig& cfg, const BitsetView& bitset) const { constexpr float kHnswSearchKnnBFFilterThreshold = 0.93f; - constexpr float kHnswSearchRangeBFFilterThreshold = 0.97f; + // constexpr float kHnswSearchRangeBFFilterThreshold = 0.97f; constexpr float kHnswSearchBFTopkThreshold = 0.5f; auto k = cfg.k.value(); @@ -736,11 +726,11 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode { if (data_format == DataFormatEnum::fp32) { hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); } else if (data_format == DataFormatEnum::fp16) { - hnsw_index = std::make_unique( - dim, faiss::ScalarQuantizer::QT_fp16, hnsw_cfg.M.value()); + hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_fp16, + hnsw_cfg.M.value()); } else if (data_format == DataFormatEnum::bf16) { - hnsw_index = std::make_unique( - dim, faiss::ScalarQuantizer::QT_bf16, hnsw_cfg.M.value()); + hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_bf16, + hnsw_cfg.M.value()); } else { LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value(); return Status::invalid_metric_type; @@ -749,11 +739,11 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode { if (data_format == DataFormatEnum::fp32) { hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); } else if (data_format == DataFormatEnum::fp16) { - hnsw_index = std::make_unique( - dim, faiss::ScalarQuantizer::QT_fp16, hnsw_cfg.M.value(), metric.value()); + hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_fp16, + hnsw_cfg.M.value(), metric.value()); } else if (data_format == DataFormatEnum::bf16) { - hnsw_index = std::make_unique( - dim, faiss::ScalarQuantizer::QT_bf16, hnsw_cfg.M.value(), metric.value()); + hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_bf16, + hnsw_cfg.M.value(), metric.value()); } else { LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value(); return Status::invalid_metric_type; @@ -786,18 +776,15 @@ class BaseFaissRegularIndexHNSWFlatNodeTemplate : public BaseFaissRegularIndexHN } }; - namespace { // a supporting function expected get_sq_quantizer_type(const std::string& sq_type) { - std::map sq_types = { - {"SQ6", faiss::ScalarQuantizer::QT_6bit}, - {"SQ8", faiss::ScalarQuantizer::QT_8bit}, - {"FP16", faiss::ScalarQuantizer::QT_fp16}, - {"BF16", faiss::ScalarQuantizer::QT_bf16} - }; + std::map sq_types = {{"SQ6", faiss::ScalarQuantizer::QT_6bit}, + {"SQ8", faiss::ScalarQuantizer::QT_8bit}, + {"FP16", faiss::ScalarQuantizer::QT_fp16}, + {"BF16", faiss::ScalarQuantizer::QT_bf16}}; // todo: tolower auto itr = sq_types.find(sq_type); @@ -825,8 +812,7 @@ is_flat_refine(const std::optional& refine_type) { auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); if (!refine_sq_type.has_value()) { LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value(); - return expected::Err( - Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value())); + return expected::Err(Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value())); } return false; @@ -834,18 +820,14 @@ 
is_flat_refine(const std::optional& refine_type) { // pick a refine index expected> -pick_refine_index( - const DataFormatEnum data_format, - const std::optional& refine_type, - std::unique_ptr&& hnsw_index -) { +pick_refine_index(const DataFormatEnum data_format, const std::optional& refine_type, + std::unique_ptr&& hnsw_index) { // yes // grab a type of a refine index expected is_fp32_flat = is_flat_refine(refine_type); if (!is_fp32_flat.has_value()) { - return expected>::Err( - Status::invalid_args, ""); + return expected>::Err(Status::invalid_args, ""); } const bool is_fp32_flat_v = is_fp32_flat.value(); @@ -854,10 +836,8 @@ pick_refine_index( if (data_format == DataFormatEnum::fp16) { // make sure that we're using fp16 refine auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); - if (!(refine_sq_type.has_value() && ( - refine_sq_type.value() != faiss::ScalarQuantizer::QT_bf16 && - !is_fp32_flat_v))) { - + if (!(refine_sq_type.has_value() && + (refine_sq_type.value() != faiss::ScalarQuantizer::QT_bf16 && !is_fp32_flat_v))) { LOG_KNOWHERE_ERROR_ << "fp16 input data does not accept bf16 or fp32 as a refine index."; return expected>::Err( Status::invalid_args, "fp16 input data does not accept bf16 or fp32 as a refine index."); @@ -867,10 +847,8 @@ pick_refine_index( if (data_format == DataFormatEnum::bf16) { // make sure that we're using bf16 refine auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); - if (!(refine_sq_type.has_value() && ( - refine_sq_type.value() != faiss::ScalarQuantizer::QT_fp16 && - !is_fp32_flat_v))) { - + if (!(refine_sq_type.has_value() && + (refine_sq_type.value() != faiss::ScalarQuantizer::QT_fp16 && !is_fp32_flat_v))) { LOG_KNOWHERE_ERROR_ << "bf16 input data does not accept fp16 or fp32 as a refine index."; return expected>::Err( Status::invalid_args, "bf16 input data does not accept fp16 or fp32 as a refine index."); @@ -904,10 +882,7 @@ pick_refine_index( // create an sq auto sq_refine = std::make_unique( - local_hnsw_index->storage->d, - refine_sq_type.value(), - local_hnsw_index->storage->metric_type - ); + local_hnsw_index->storage->d, refine_sq_type.value(), local_hnsw_index->storage->metric_type); auto refine_index = std::make_unique(local_hnsw_index.get(), sq_refine.get()); @@ -926,9 +901,10 @@ pick_refine_index( // class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { - public: - BaseFaissRegularIndexHNSWSQNode(const int32_t& version, const Object& object, DataFormatEnum data_format) : - BaseFaissRegularIndexHNSWNode(version, object, data_format) {} + public: + BaseFaissRegularIndexHNSWSQNode(const int32_t& version, const Object& object, DataFormatEnum data_format) + : BaseFaissRegularIndexHNSWNode(version, object, data_format) { + } std::unique_ptr CreateConfig() const override { @@ -940,8 +916,9 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { return knowhere::IndexEnum::INDEX_FAISS_HNSW_SQ; } - protected: - Status TrainInternal(const DataSetPtr dataset, const Config& cfg) override { + protected: + Status + TrainInternal(const DataSetPtr dataset, const Config& cfg) override { // number of rows auto rows = dataset->GetRows(); // dimensionality of the data @@ -968,11 +945,9 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { std::unique_ptr hnsw_index; if (is_cosine) { - hnsw_index = std::make_unique( - dim, sq_type.value(), hnsw_cfg.M.value()); + hnsw_index = std::make_unique(dim, sq_type.value(), hnsw_cfg.M.value()); } else { - 
hnsw_index = std::make_unique( - dim, sq_type.value(), hnsw_cfg.M.value(), metric.value()); + hnsw_index = std::make_unique(dim, sq_type.value(), hnsw_cfg.M.value(), metric.value()); } hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); @@ -1014,19 +989,20 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { } }; -template +template class BaseFaissRegularIndexHNSWSQNodeTemplate : public BaseFaissRegularIndexHNSWSQNode { - public: - BaseFaissRegularIndexHNSWSQNodeTemplate(const int32_t& version, const Object& object) : - BaseFaissRegularIndexHNSWSQNode(version, object, datatype_v) {} + public: + BaseFaissRegularIndexHNSWSQNodeTemplate(const int32_t& version, const Object& object) + : BaseFaissRegularIndexHNSWSQNode(version, object, datatype_v) { + } }; - // this index trains PQ and HNSW+FLAT separately, then constructs HNSW+PQ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { - public: - BaseFaissRegularIndexHNSWPQNode(const int32_t& version, const Object& object, DataFormatEnum data_format) : - BaseFaissRegularIndexHNSWNode(version, object, data_format) {} + public: + BaseFaissRegularIndexHNSWPQNode(const int32_t& version, const Object& object, DataFormatEnum data_format) + : BaseFaissRegularIndexHNSWNode(version, object, data_format) { + } std::unique_ptr CreateConfig() const override { @@ -1038,10 +1014,11 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { return knowhere::IndexEnum::INDEX_FAISS_HNSW_PQ; } - protected: + protected: std::unique_ptr tmp_index_pq; - Status TrainInternal(const DataSetPtr dataset, const Config& cfg) override { + Status + TrainInternal(const DataSetPtr dataset, const Config& cfg) override { // number of rows auto rows = dataset->GetRows(); // dimensionality of the data @@ -1076,10 +1053,10 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { if (is_cosine) { pq_index = std::make_unique(dim, hnsw_cfg.m.value(), hnsw_cfg.nbits.value()); } else { - pq_index = std::make_unique(dim, hnsw_cfg.m.value(), hnsw_cfg.nbits.value(), metric.value()); + pq_index = + std::make_unique(dim, hnsw_cfg.m.value(), hnsw_cfg.nbits.value(), metric.value()); } - // should refine be used? std::unique_ptr final_index; if (hnsw_cfg.refine.value_or(false)) { @@ -1154,8 +1131,7 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { // check if we have a refine available. faiss::IndexHNSW* index_hnsw = nullptr; - faiss::IndexRefine* const index_refine = - dynamic_cast(index.get()); + faiss::IndexRefine* const index_refine = dynamic_cast(index.get()); if (index_refine != nullptr) { index_hnsw = dynamic_cast(index_refine->base_index); @@ -1172,9 +1148,9 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { index_hnsw_pq = std::make_unique(); } - // C++ slicing - static_cast(*index_hnsw_pq) = - std::move(static_cast(*index_hnsw)); + // C++ slicing. + // we can't use move, because faiss::IndexHNSW overrides a destructor. 
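                // Reviewer note: the slicing is intentional here (and in the matching PRQ hunk
                // further down). The whole graph, i.e. the `hnsw` member with its levels, neighbor
                // lists and entry point, lives in the faiss::IndexHNSW base subobject, so
                // copy-assigning that subobject transplants the trained graph from the temporary
                // HNSW+FLAT index into the empty IndexHNSWPQ shell. The flat storage is then
                // deleted just below so that the separately trained PQ storage can take its place,
                // which yields HNSW+PQ without re-running graph construction.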
+ static_cast(*index_hnsw_pq) = static_cast(*index_hnsw); // clear out the storage delete index_hnsw->storage; @@ -1201,19 +1177,20 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { } }; -template +template class BaseFaissRegularIndexHNSWPQNodeTemplate : public BaseFaissRegularIndexHNSWPQNode { - public: - BaseFaissRegularIndexHNSWPQNodeTemplate(const int32_t& version, const Object& object) : - BaseFaissRegularIndexHNSWPQNode(version, object, datatype_v) {} + public: + BaseFaissRegularIndexHNSWPQNodeTemplate(const int32_t& version, const Object& object) + : BaseFaissRegularIndexHNSWPQNode(version, object, datatype_v) { + } }; - // this index trains PRQ and HNSW+FLAT separately, then constructs HNSW+PRQ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { - public: - BaseFaissRegularIndexHNSWPRQNode(const int32_t& version, const Object& object, DataFormatEnum data_format) : - BaseFaissRegularIndexHNSWNode(version, object, data_format) {} + public: + BaseFaissRegularIndexHNSWPRQNode(const int32_t& version, const Object& object, DataFormatEnum data_format) + : BaseFaissRegularIndexHNSWNode(version, object, data_format) { + } std::unique_ptr CreateConfig() const override { @@ -1225,16 +1202,15 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { return knowhere::IndexEnum::INDEX_FAISS_HNSW_PRQ; } - protected: + protected: std::unique_ptr tmp_index_prq; - Status TrainInternal(const DataSetPtr dataset, const Config& cfg) override { + Status + TrainInternal(const DataSetPtr dataset, const Config& cfg) override { // number of rows auto rows = dataset->GetRows(); // dimensionality of the data auto dim = dataset->GetDim(); - // data - auto data = dataset->GetTensor(); // config auto hnsw_cfg = static_cast(cfg); @@ -1262,9 +1238,9 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { // prq faiss::AdditiveQuantizer::Search_type_t prq_search_type = - (metric.value() == faiss::MetricType::METRIC_INNER_PRODUCT) ? - faiss::AdditiveQuantizer::Search_type_t::ST_LUT_nonorm : - faiss::AdditiveQuantizer::Search_type_t::ST_norm_float; + (metric.value() == faiss::MetricType::METRIC_INNER_PRODUCT) + ? faiss::AdditiveQuantizer::Search_type_t::ST_LUT_nonorm + : faiss::AdditiveQuantizer::Search_type_t::ST_norm_float; std::unique_ptr prq_index; if (is_cosine) { @@ -1324,7 +1300,6 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { return Status::empty_index; } - auto data = dataset->GetTensor(); auto rows = dataset->GetRows(); try { // hnsw @@ -1349,8 +1324,7 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { // check if we have a refine available. faiss::IndexHNSW* index_hnsw = nullptr; - faiss::IndexRefine* const index_refine = - dynamic_cast(index.get()); + faiss::IndexRefine* const index_refine = dynamic_cast(index.get()); if (index_refine != nullptr) { index_hnsw = dynamic_cast(index_refine->base_index); @@ -1368,8 +1342,8 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { } // C++ slicing - static_cast(*index_hnsw_prq) = - std::move(static_cast(*index_hnsw)); + // we can't use move, because faiss::IndexHNSW overrides a destructor. 
+ static_cast(*index_hnsw_prq) = static_cast(*index_hnsw); // clear out the storage delete index_hnsw->storage; @@ -1396,14 +1370,14 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { } }; -template +template class BaseFaissRegularIndexHNSWPRQNodeTemplate : public BaseFaissRegularIndexHNSWPRQNode { - public: - BaseFaissRegularIndexHNSWPRQNodeTemplate(const int32_t& version, const Object& object) : - BaseFaissRegularIndexHNSWPRQNode(version, object, datatype_v) {} + public: + BaseFaissRegularIndexHNSWPRQNodeTemplate(const int32_t& version, const Object& object) + : BaseFaissRegularIndexHNSWPRQNode(version, object, datatype_v) { + } }; - // KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNodeTemplate, fp32); KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNodeTemplate, fp16); diff --git a/src/index/hnsw/faiss_hnsw_config.h b/src/index/hnsw/faiss_hnsw_config.h index 0de816cc5..514cd8ef4 100644 --- a/src/index/hnsw/faiss_hnsw_config.h +++ b/src/index/hnsw/faiss_hnsw_config.h @@ -115,10 +115,10 @@ class FaissHnswConfig : public BaseConfig { } protected: - bool WhetherAcceptableRefineType(const std::string& refine_type) { + bool + WhetherAcceptableRefineType(const std::string& refine_type) { // 'flat' is identical to 'fp32' - std::vector allowed_list = { - "SQ6", "SQ8", "FP16", "BF16", "FP32", "FLAT"}; + std::vector allowed_list = {"SQ6", "SQ8", "FP16", "BF16", "FP32", "FLAT"}; // todo: tolower() @@ -162,15 +162,12 @@ class FaissHnswFlatConfig : public FaissHnswConfig { }; class FaissHnswSqConfig : public FaissHnswConfig { -public: + public: // user can use quant_type to control quantizer type. // we have fp16, bf16, etc, so '8', '4' and '6' is insufficient CFG_STRING sq_type; KNOHWERE_DECLARE_CONFIG(FaissHnswSqConfig) { - KNOWHERE_CONFIG_DECLARE_FIELD(sq_type) - .set_default("SQ8") - .description("scalar quantizer type") - .for_train(); + KNOWHERE_CONFIG_DECLARE_FIELD(sq_type).set_default("SQ8").description("scalar quantizer type").for_train(); }; Status @@ -202,11 +199,11 @@ class FaissHnswSqConfig : public FaissHnswConfig { return Status::success; } -private: - bool WhetherAcceptableQuantType(const std::string& sq_type) { + private: + bool + WhetherAcceptableQuantType(const std::string& sq_type) { // todo: add more - std::vector allowed_list = { - "SQ6", "SQ8", "FP16", "BF16"}; + std::vector allowed_list = {"SQ6", "SQ8", "FP16", "BF16"}; // todo: tolower() @@ -221,28 +218,20 @@ class FaissHnswSqConfig : public FaissHnswConfig { }; class FaissHnswPqConfig : public FaissHnswConfig { -public: + public: // number of subquantizers CFG_INT m; // number of bits per subquantizer CFG_INT nbits; KNOHWERE_DECLARE_CONFIG(FaissHnswPqConfig) { - KNOWHERE_CONFIG_DECLARE_FIELD(m) - .description("m") - .set_default(32) - .for_train() - .set_range(1, 65536); - KNOWHERE_CONFIG_DECLARE_FIELD(nbits) - .description("nbits") - .set_default(8) - .for_train() - .set_range(1, 16); + KNOWHERE_CONFIG_DECLARE_FIELD(m).description("m").set_default(32).for_train().set_range(1, 65536); + KNOWHERE_CONFIG_DECLARE_FIELD(nbits).description("nbits").set_default(8).for_train().set_range(1, 16); } }; class FaissHnswPrqConfig : public FaissHnswConfig { -public: + public: // number of subquantizer splits CFG_INT m; // number of residual quantizers @@ -250,20 +239,9 @@ class FaissHnswPrqConfig : public FaissHnswConfig { // number of bits per subquantizer CFG_INT nbits; KNOHWERE_DECLARE_CONFIG(FaissHnswPrqConfig) { - 
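        // Reviewer note: per the field descriptions below, a PRQ codec splits each vector into
        // `m` sub-vectors and encodes every sub-vector with a chain of `nrq` residual codebooks of
        // `nbits` bits each, so a code costs roughly m * nrq * nbits bits per vector. A hypothetical
        // build config touching only these fields could look like {"m": 4, "nrq": 2, "nbits": 8}
        // (values are illustrative; m=4 and nrq=2 mirror prq_m/prq_num in the benchmark above).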
KNOWHERE_CONFIG_DECLARE_FIELD(m) - .description("Number of splits") - .set_default(2) - .for_train() - .set_range(1, 65536); - KNOWHERE_CONFIG_DECLARE_FIELD(nrq) - .description("Number of residual subquantizers") - .for_train() - .set_range(1, 64); - KNOWHERE_CONFIG_DECLARE_FIELD(nbits) - .description("nbits") - .set_default(8) - .for_train() - .set_range(1, 64); + KNOWHERE_CONFIG_DECLARE_FIELD(m).description("Number of splits").set_default(2).for_train().set_range(1, 65536); + KNOWHERE_CONFIG_DECLARE_FIELD(nrq).description("Number of residual subquantizers").for_train().set_range(1, 64); + KNOWHERE_CONFIG_DECLARE_FIELD(nbits).description("nbits").set_default(8).for_train().set_range(1, 64); } }; From 2006b77c5b005a7cacdaaa1d567be363d44ef55e Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Sun, 11 Aug 2024 19:52:32 -0400 Subject: [PATCH 14/21] Faster Faiss HNSW training Signed-off-by: Alexandr Guzhva --- thirdparty/faiss/faiss/impl/HNSW.cpp | 177 ++++++++++++++++++++++----- 1 file changed, 143 insertions(+), 34 deletions(-) diff --git a/thirdparty/faiss/faiss/impl/HNSW.cpp b/thirdparty/faiss/faiss/impl/HNSW.cpp index 3647f0f6f..4e6fb2944 100644 --- a/thirdparty/faiss/faiss/impl/HNSW.cpp +++ b/thirdparty/faiss/faiss/impl/HNSW.cpp @@ -381,26 +381,82 @@ void search_neighbors_to_add( // loop over neighbors size_t begin, end; hnsw.neighbor_range(currNode, level, &begin, &end); - for (size_t i = begin; i < end; i++) { - storage_idx_t nodeId = hnsw.neighbors[i]; + + // // baseline version + // for (size_t i = begin; i < end; i++) { + // storage_idx_t nodeId = hnsw.neighbors[i]; + // if (nodeId < 0) + // break; + // if (vt.get(nodeId)) + // continue; + // vt.set(nodeId); + + // float dis = qdis(nodeId); + // NodeDistFarther evE1(dis, nodeId); + + // if (results.size() < hnsw.efConstruction || results.top().d > dis) { + // results.emplace(dis, nodeId); + // candidates.emplace(dis, nodeId); + // if (results.size() > hnsw.efConstruction) { + // results.pop(); + // } + // } + // } + + // the following version processes 4 neighbors at a time + auto update_with_candidate = [&](const storage_idx_t idx, const float dis) { + if (results.size() < hnsw.efConstruction || results.top().d > dis) { + results.emplace(dis, idx); + candidates.emplace(dis, idx); + if (results.size() > hnsw.efConstruction) { + results.pop(); + } + } + }; + + int n_buffered = 0; + storage_idx_t buffered_ids[4]; + + for (size_t j = begin; j < end; j++) { + storage_idx_t nodeId = hnsw.neighbors[j]; if (nodeId < 0) break; - if (vt.get(nodeId)) + if (vt.get(nodeId)) { continue; + } vt.set(nodeId); - float dis = qdis(nodeId); - NodeDistFarther evE1(dis, nodeId); + buffered_ids[n_buffered] = nodeId; + n_buffered += 1; - if (results.size() < hnsw.efConstruction || results.top().d > dis) { - results.emplace(dis, nodeId); - candidates.emplace(dis, nodeId); - if (results.size() > hnsw.efConstruction) { - results.pop(); + if (n_buffered == 4) { + float dis[4]; + qdis.distances_batch_4( + buffered_ids[0], + buffered_ids[1], + buffered_ids[2], + buffered_ids[3], + dis[0], + dis[1], + dis[2], + dis[3]); + + for (size_t id4 = 0; id4 < 4; id4++) { + update_with_candidate(buffered_ids[id4], dis[id4]); } + + n_buffered = 0; } } + + // process leftovers + for (size_t icnt = 0; icnt < n_buffered; icnt++) { + float dis = qdis(buffered_ids[icnt]); + update_with_candidate(buffered_ids[icnt], dis); + } + } + vt.advance(); } @@ -423,18 +479,67 @@ HNSWStats greedy_update_nearest( size_t begin, end; hnsw.neighbor_range(nearest, level, &begin, &end); + // // 
baseline version + // size_t ndis = 0; + // for (size_t i = begin; i < end; i++) { + // storage_idx_t v = hnsw.neighbors[i]; + // if (v < 0) + // break; + // ndis += 1; + // float dis = qdis(v); + // if (dis < d_nearest) { + // nearest = v; + // d_nearest = dis; + // } + // } + + // the following version processes 4 neighbors at a time + auto update_with_candidate = [&](const storage_idx_t idx, const float dis) { + if (dis < d_nearest) { + nearest = idx; + d_nearest = dis; + } + }; + size_t ndis = 0; - for (size_t i = begin; i < end; i++, ndis++) { - storage_idx_t v = hnsw.neighbors[i]; + int n_buffered = 0; + storage_idx_t buffered_ids[4]; + + for (size_t j = begin; j < end; j++) { + storage_idx_t v = hnsw.neighbors[j]; if (v < 0) break; - float dis = qdis(v); - if (dis < d_nearest) { - nearest = v; - d_nearest = dis; + ndis += 1; + + buffered_ids[n_buffered] = v; + n_buffered += 1; + + if (n_buffered == 4) { + float dis[4]; + qdis.distances_batch_4( + buffered_ids[0], + buffered_ids[1], + buffered_ids[2], + buffered_ids[3], + dis[0], + dis[1], + dis[2], + dis[3]); + + for (size_t id4 = 0; id4 < 4; id4++) { + update_with_candidate(buffered_ids[id4], dis[id4]); + } + + n_buffered = 0; } } + // process leftovers + for (size_t icnt = 0; icnt < n_buffered; icnt++) { + float dis = qdis(buffered_ids[icnt]); + update_with_candidate(buffered_ids[icnt], dis); + } + // update stats stats.ndis += ndis; stats.nhops += 1; @@ -631,7 +736,7 @@ int search_from_candidates( // the following version processes 4 neighbors at a time size_t jmax = begin; for (size_t j = begin; j < end; j++) { - int v1 = hnsw.neighbors[j]; + storage_idx_t v1 = hnsw.neighbors[j]; if (v1 < 0) break; @@ -639,13 +744,13 @@ int search_from_candidates( jmax += 1; } - int counter = 0; - size_t saved_j[4]; + int n_buffered = 0; + storage_idx_t buffered_ids[4]; ndis += jmax - begin; threshold = res.threshold; - auto add_to_heap = [&](const size_t idx, const float dis) { + auto add_to_heap = [&](const storage_idx_t idx, const float dis) { if (!sel || sel->is_member(idx)) { if (dis < threshold) { if (res.add_result(dis, idx)) { @@ -658,36 +763,40 @@ int search_from_candidates( }; for (size_t j = begin; j < jmax; j++) { - int v1 = hnsw.neighbors[j]; + storage_idx_t v1 = hnsw.neighbors[j]; - bool vget = vt.get(v1); + if (vt.get(v1)) { + continue; + } vt.set(v1); - saved_j[counter] = v1; - counter += vget ? 
0 : 1; - if (counter == 4) { + buffered_ids[n_buffered] = v1; + n_buffered += 1; + + if (n_buffered == 4) { float dis[4]; qdis.distances_batch_4( - saved_j[0], - saved_j[1], - saved_j[2], - saved_j[3], + buffered_ids[0], + buffered_ids[1], + buffered_ids[2], + buffered_ids[3], dis[0], dis[1], dis[2], dis[3]); for (size_t id4 = 0; id4 < 4; id4++) { - add_to_heap(saved_j[id4], dis[id4]); + add_to_heap(buffered_ids[id4], dis[id4]); } - counter = 0; + n_buffered = 0; } } - for (size_t icnt = 0; icnt < counter; icnt++) { - float dis = qdis(saved_j[icnt]); - add_to_heap(saved_j[icnt], dis); + // process leftovers + for (size_t icnt = 0; icnt < n_buffered; icnt++) { + float dis = qdis(buffered_ids[icnt]); + add_to_heap(buffered_ids[icnt], dis); } nstep++; From 094fa745ef20472ff6f4c7a5fc5f79a81be0c44a Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Mon, 19 Aug 2024 09:55:24 -0400 Subject: [PATCH 15/21] increase range search default param (PR 764) Signed-off-by: Alexandr Guzhva --- src/index/hnsw/faiss_hnsw_config.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/index/hnsw/faiss_hnsw_config.h b/src/index/hnsw/faiss_hnsw_config.h index 514cd8ef4..129331d3b 100644 --- a/src/index/hnsw/faiss_hnsw_config.h +++ b/src/index/hnsw/faiss_hnsw_config.h @@ -21,7 +21,7 @@ namespace { constexpr const CFG_INT::value_type kIteratorSeedEf = 40; constexpr const CFG_INT::value_type kEfMinValue = 16; -constexpr const CFG_INT::value_type kDefaultRangeSearchEf = 16; +constexpr const CFG_INT::value_type kDefaultRangeSearchEf = 512; } // namespace From 413c8f8214ffbeac2e607de41166fb7b06dfd996 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 20 Aug 2024 11:39:48 -0400 Subject: [PATCH 16/21] lower chunk size in faiss_hnsw.cc :: add_to_index() Signed-off-by: Alexandr Guzhva --- src/index/hnsw/faiss_hnsw.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index a3f3c77c0..9f98974fb 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -353,7 +353,7 @@ add_to_index(faiss::Index* const __restrict index, const DataSetPtr& dataset, co index->add(rows, reinterpret_cast(data)); } else { // convert data into float in pieces and add to the index - constexpr int64_t n_tmp_rows = 65536; + constexpr int64_t n_tmp_rows = 4096; std::vector tmp(n_tmp_rows * dim); for (int64_t irow = 0; irow < rows; irow += n_tmp_rows) { From 42555789b1ccc63c16c8d684f55852a1332dea2b Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 20 Aug 2024 12:59:50 -0400 Subject: [PATCH 17/21] GetVectorByIds() returns data in an appropriate data format Signed-off-by: Alexandr Guzhva --- include/knowhere/dataset.h | 22 +---- src/index/hnsw/faiss_hnsw.cc | 169 +++++++++++++++++++++++++++-------- 2 files changed, 137 insertions(+), 54 deletions(-) diff --git a/include/knowhere/dataset.h b/include/knowhere/dataset.h index 8e443e649..5ad4f7df2 100644 --- a/include/knowhere/dataset.h +++ b/include/knowhere/dataset.h @@ -123,14 +123,9 @@ class DataSet : public std::enable_shared_from_this { this->data_[meta::TENSOR] = Var(std::in_place_index<3>, tensor); } + template void - SetTensor(std::unique_ptr&& tensor) { - std::unique_lock lock(mutex_); - this->data_[meta::TENSOR] = Var(std::in_place_index<3>, tensor.release()); - } - - void - SetTensor(std::unique_ptr&& tensor) { + SetTensor(std::unique_ptr&& tensor) { std::unique_lock lock(mutex_); this->data_[meta::TENSOR] = Var(std::in_place_index<3>, tensor.release()); } @@ -326,18 
+321,9 @@ GenResultDataSet(const int64_t rows, const int64_t dim, const void* tensor) { return ret_ds; } +template inline DataSetPtr -GenResultDataSet(const int64_t rows, const int64_t dim, std::unique_ptr&& tensor) { - auto ret_ds = std::make_shared(); - ret_ds->SetRows(rows); - ret_ds->SetDim(dim); - ret_ds->SetTensor(std::move(tensor)); - ret_ds->SetIsOwner(true); - return ret_ds; -} - -inline DataSetPtr -GenResultDataSet(const int64_t rows, const int64_t dim, std::unique_ptr&& tensor) { +GenResultDataSet(const int64_t rows, const int64_t dim, std::unique_ptr&& tensor) { auto ret_ds = std::make_shared(); ret_ds->SetRows(rows); ret_ds->SetDim(dim); diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 9f98974fb..a774958c3 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -147,35 +147,6 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode { : BaseFaissIndexNode(version, object), index{nullptr} { } - expected - GetVectorByIds(const DataSetPtr dataset) const override { - if (this->index == nullptr) { - return expected::Err(Status::empty_index, "index not loaded"); - } - if (!this->index->is_trained) { - return expected::Err(Status::index_not_trained, "index not trained"); - } - - auto dim = Dim(); - auto rows = dataset->GetRows(); - auto ids = dataset->GetIds(); - - try { - auto data = std::make_unique(dim * rows); - - for (int64_t i = 0; i < rows; i++) { - const int64_t id = ids[i]; - assert(id >= 0 && id < index->ntotal); - index->reconstruct(id, data.get() + i * dim); - } - - return GenResultDataSet(rows, dim, std::move(data)); - } catch (const std::exception& e) { - LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); - return expected::Err(Status::faiss_inner_error, e.what()); - } - } - Status Serialize(BinarySet& binset) const override { if (index == nullptr) { @@ -306,21 +277,63 @@ namespace { // bool -convert_rows(const void* const __restrict src_in, float* const __restrict dst, const DataFormatEnum data_format, +convert_rows_to_fp32(const void* const __restrict src_in, float* const __restrict dst, const DataFormatEnum src_data_format, const size_t start_row, const size_t nrows, const size_t dim) { - if (data_format == DataFormatEnum::fp16) { + if (src_data_format == DataFormatEnum::fp16) { const knowhere::fp16* const src = reinterpret_cast(src_in); for (size_t i = 0; i < nrows * dim; i++) { dst[i] = (float)(src[i + start_row * dim]); } return true; - } else if (data_format == DataFormatEnum::bf16) { + } else if (src_data_format == DataFormatEnum::bf16) { const knowhere::bf16* const src = reinterpret_cast(src_in); for (size_t i = 0; i < nrows * dim; i++) { dst[i] = (float)(src[i + start_row * dim]); } + return true; + } else if (src_data_format == DataFormatEnum::fp32) { + const knowhere::fp32* const src = reinterpret_cast(src_in); + for (size_t i = 0; i < nrows * dim; i++) { + dst[i] = src[i + start_row * dim]; + } + + return true; + } else { + // unknown + return false; + } +} + +bool convert_rows_from_fp32( + const float* const __restrict src, + void* const __restrict dst_in, + const DataFormatEnum dst_data_format, + const size_t start_row, + const size_t nrows, + const size_t dim +) { + if (dst_data_format == DataFormatEnum::fp16) { + knowhere::fp16* const dst = reinterpret_cast(dst_in); + for (size_t i = 0; i < nrows * dim; i++) { + dst[i + start_row * dim] = (knowhere::fp16)src[i]; + } + + return true; + } else if (dst_data_format == DataFormatEnum::bf16) { + knowhere::bf16* const dst = 
reinterpret_cast(dst_in); + for (size_t i = 0; i < nrows * dim; i++) { + dst[i + start_row * dim] = (knowhere::bf16)src[i]; + } + + return true; + } else if (dst_data_format == DataFormatEnum::fp32) { + knowhere::fp32* const dst = reinterpret_cast(dst_in); + for (size_t i = 0; i < nrows * dim; i++) { + dst[i + start_row * dim] = src[i]; + } + return true; } else { // unknown @@ -354,20 +367,20 @@ add_to_index(faiss::Index* const __restrict index, const DataSetPtr& dataset, co } else { // convert data into float in pieces and add to the index constexpr int64_t n_tmp_rows = 4096; - std::vector tmp(n_tmp_rows * dim); + std::unique_ptr tmp = std::make_unique(n_tmp_rows * dim); for (int64_t irow = 0; irow < rows; irow += n_tmp_rows) { const int64_t start_row = irow; const int64_t end_row = std::min(rows, start_row + n_tmp_rows); const int64_t count_rows = end_row - start_row; - if (!convert_rows(data, tmp.data(), data_format, start_row, count_rows, dim)) { + if (!convert_rows_to_fp32(data, tmp.get(), data_format, start_row, count_rows, dim)) { LOG_KNOWHERE_ERROR_ << "Unsupported data format"; return Status::invalid_args; } // add - index->add(count_rows, tmp.data()); + index->add(count_rows, tmp.get()); } } @@ -455,7 +468,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { if (data_format == DataFormatEnum::fp32) { cur_query = (const float*)data + idx * dim; } else { - convert_rows(data, cur_query_tmp.data(), data_format, idx, 1, dim); + convert_rows_to_fp32(data, cur_query_tmp.data(), data_format, idx, 1, dim); cur_query = cur_query_tmp.data(); } @@ -613,6 +626,90 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { return res; } + expected + GetVectorByIds(const DataSetPtr dataset) const override { + if (this->index == nullptr) { + return expected::Err(Status::empty_index, "index not loaded"); + } + if (!this->index->is_trained) { + return expected::Err(Status::index_not_trained, "index not trained"); + } + + auto dim = Dim(); + auto rows = dataset->GetRows(); + auto ids = dataset->GetIds(); + + try { + if (data_format == DataFormatEnum::fp32) { + // perform a direct reconstruction for fp32 data + auto data = std::make_unique(dim * rows); + + for (int64_t i = 0; i < rows; i++) { + const int64_t id = ids[i]; + assert(id >= 0 && id < index->ntotal); + index->reconstruct(id, data.get() + i * dim); + } + + return GenResultDataSet(rows, dim, std::move(data)); + } else if (data_format == DataFormatEnum::fp16) { + auto data = std::make_unique(dim * rows); + + // faiss produces fp32 data format, we need some other format. + // Let's create a temporary fp32 buffer for this. + auto tmp = std::make_unique(dim); + + for (int64_t i = 0; i < rows; i++) { + const int64_t id = ids[i]; + assert(id >= 0 && id < index->ntotal); + index->reconstruct(id, tmp.get()); + + if (!convert_rows_from_fp32( + tmp.get(), + data.get(), + data_format, + i, + 1, + dim) + ) { + return expected::Err(Status::invalid_args, "Unsupported data format"); + } + } + + return GenResultDataSet(rows, dim, std::move(data)); + } else if (data_format == DataFormatEnum::bf16) { + auto data = std::make_unique(dim * rows); + + // faiss produces fp32 data format, we need some other format. + // Let's create a temporary fp32 buffer for this. 
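                // (Same pattern as the fp16 branch above: reconstruct() fills this dim-sized fp32
                // scratch buffer one id at a time and convert_rows_from_fp32() rewrites it into
                // row i of the bf16 output, so the scratch stays at O(dim) no matter how many ids
                // are requested.)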
+ auto tmp = std::make_unique(dim); + + for (int64_t i = 0; i < rows; i++) { + const int64_t id = ids[i]; + assert(id >= 0 && id < index->ntotal); + index->reconstruct(id, tmp.get()); + + if (!convert_rows_from_fp32( + tmp.get(), + data.get(), + data_format, + i, + 1, + dim) + ) { + return expected::Err(Status::invalid_args, "Unsupported data format"); + } + } + + return GenResultDataSet(rows, dim, std::move(data)); + } else { + return expected::Err(Status::invalid_args, "Unsupported data format"); + } + } catch (const std::exception& e) { + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); + return expected::Err(Status::faiss_inner_error, e.what()); + } + } + protected: DataFormatEnum data_format; From aefddad065d0edd95019c27f030cb1137f061d57 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 20 Aug 2024 14:16:13 -0400 Subject: [PATCH 18/21] add lower-case comparisons Signed-off-by: Alexandr Guzhva --- include/knowhere/tolower.h | 33 ++++++++++++++++++++++++++++++ src/index/hnsw/faiss_hnsw.cc | 18 ++++++++-------- src/index/hnsw/faiss_hnsw_config.h | 23 ++++++++------------- 3 files changed, 51 insertions(+), 23 deletions(-) create mode 100644 include/knowhere/tolower.h diff --git a/include/knowhere/tolower.h b/include/knowhere/tolower.h new file mode 100644 index 000000000..51be51761 --- /dev/null +++ b/include/knowhere/tolower.h @@ -0,0 +1,33 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// valributed under the License is valributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. 
+ +#pragma once + +#include +#include + +namespace knowhere { + +static inline std::string str_to_lower(std::string s) { + std::transform( + s.begin(), + s.end(), + s.begin(), + [](unsigned char c){ return std::tolower(c); } + ); + + return s; +} + +} diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index a774958c3..9c3a04bb4 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -878,16 +878,17 @@ namespace { // a supporting function expected get_sq_quantizer_type(const std::string& sq_type) { - std::map sq_types = {{"SQ6", faiss::ScalarQuantizer::QT_6bit}, - {"SQ8", faiss::ScalarQuantizer::QT_8bit}, - {"FP16", faiss::ScalarQuantizer::QT_fp16}, - {"BF16", faiss::ScalarQuantizer::QT_bf16}}; + std::map sq_types = {{"sq6", faiss::ScalarQuantizer::QT_6bit}, + {"sq8", faiss::ScalarQuantizer::QT_8bit}, + {"fp16", faiss::ScalarQuantizer::QT_fp16}, + {"bf16", faiss::ScalarQuantizer::QT_bf16}}; // todo: tolower - auto itr = sq_types.find(sq_type); + auto sq_type_tolower = str_to_lower(sq_type); + auto itr = sq_types.find(sq_type_tolower); if (itr == sq_types.cend()) { return expected::Err( - Status::invalid_args, fmt::format("invalid scalar quantizer type ({})", sq_type)); + Status::invalid_args, fmt::format("invalid scalar quantizer type ({})", sq_type_tolower)); } return itr->second; @@ -901,12 +902,13 @@ is_flat_refine(const std::optional& refine_type) { }; // todo: tolower - if (refine_type.value() == "FP32" || refine_type.value() == "FLAT") { + std::string refine_type_tolower = str_to_lower(refine_type.value()); + if (refine_type_tolower == "fp32" || refine_type_tolower == "flat") { return true; }; // parse - auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); + auto refine_sq_type = get_sq_quantizer_type(refine_type_tolower); if (!refine_sq_type.has_value()) { LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value(); return expected::Err(Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value())); diff --git a/src/index/hnsw/faiss_hnsw_config.h b/src/index/hnsw/faiss_hnsw_config.h index 129331d3b..abbf49c63 100644 --- a/src/index/hnsw/faiss_hnsw_config.h +++ b/src/index/hnsw/faiss_hnsw_config.h @@ -14,6 +14,7 @@ #include "knowhere/comp/index_param.h" #include "knowhere/config.h" +#include "knowhere/tolower.h" namespace knowhere { @@ -118,12 +119,11 @@ class FaissHnswConfig : public BaseConfig { bool WhetherAcceptableRefineType(const std::string& refine_type) { // 'flat' is identical to 'fp32' - std::vector allowed_list = {"SQ6", "SQ8", "FP16", "BF16", "FP32", "FLAT"}; - - // todo: tolower() + std::vector allowed_list = {"sq6", "sq8", "fp16", "bf16", "fp32", "flat"}; + std::string refine_type_tolower = str_to_lower(refine_type); for (const auto& allowed : allowed_list) { - if (refine_type == allowed) { + if (refine_type_tolower == allowed) { return true; } } @@ -145,13 +145,7 @@ class FaissHnswFlatConfig : public FaissHnswConfig { // check our parameters if (param_type == PARAM_TYPE::TRAIN) { // prohibit refine - if (refine.value_or(false)) { - *err_msg = "refine is not supported for this index"; - LOG_KNOWHERE_ERROR_ << *err_msg; - return Status::invalid_value_in_json; - } - - if (refine_type.has_value()) { + if (refine.value_or(false) || refine_type.has_value() || refine_k.has_value()) { *err_msg = "refine is not supported for this index"; LOG_KNOWHERE_ERROR_ << *err_msg; return Status::invalid_value_in_json; @@ -203,12 +197,11 @@ class FaissHnswSqConfig : public FaissHnswConfig { bool 
WhetherAcceptableQuantType(const std::string& sq_type) { // todo: add more - std::vector allowed_list = {"SQ6", "SQ8", "FP16", "BF16"}; - - // todo: tolower() + std::vector allowed_list = {"sq6", "sq8", "fp16", "bf16"}; + std::string sq_type_tolower = str_to_lower(sq_type); for (const auto& allowed : allowed_list) { - if (sq_type == allowed) { + if (sq_type_tolower == allowed) { return true; } } From 708b63685ced592b51e55c2b0916a94c3fc2a8dc Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 20 Aug 2024 14:20:48 -0400 Subject: [PATCH 19/21] add a missing check for refine_value.has_value() Signed-off-by: Alexandr Guzhva --- src/index/hnsw/faiss_hnsw.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 9c3a04bb4..90e095a89 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -1053,7 +1053,7 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { // should refine be used? std::unique_ptr final_index; - if (hnsw_cfg.refine.value_or(false)) { + if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { // yes auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); if (!final_index_cnd.has_value()) { @@ -1158,7 +1158,7 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { // should refine be used? std::unique_ptr final_index; - if (hnsw_cfg.refine.value_or(false)) { + if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { // yes auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); if (!final_index_cnd.has_value()) { @@ -1352,7 +1352,7 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { // should refine be used? 
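            // Reviewer note: with the refine_type.has_value() check added below (and in the two
            // matching hunks above), refine=true without an explicit refine_type no longer reaches
            // pick_refine_index() with an empty optional; the index is simply built without a
            // refine layer in that case.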
std::unique_ptr final_index; - if (hnsw_cfg.refine.value_or(false)) { + if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { // yes auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); if (!final_index_cnd.has_value()) { From a164d64ce4f56176df25dd68eb81d0957772a50c Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 20 Aug 2024 14:55:43 -0400 Subject: [PATCH 20/21] apply clang format Signed-off-by: Alexandr Guzhva --- include/knowhere/dataset.h | 4 ++-- include/knowhere/tolower.h | 12 ++++------ src/index/hnsw/faiss_hnsw.cc | 43 +++++++++++------------------------- 3 files changed, 19 insertions(+), 40 deletions(-) diff --git a/include/knowhere/dataset.h b/include/knowhere/dataset.h index 5ad4f7df2..1a7b37169 100644 --- a/include/knowhere/dataset.h +++ b/include/knowhere/dataset.h @@ -123,7 +123,7 @@ class DataSet : public std::enable_shared_from_this { this->data_[meta::TENSOR] = Var(std::in_place_index<3>, tensor); } - template + template void SetTensor(std::unique_ptr&& tensor) { std::unique_lock lock(mutex_); @@ -321,7 +321,7 @@ GenResultDataSet(const int64_t rows, const int64_t dim, const void* tensor) { return ret_ds; } -template +template inline DataSetPtr GenResultDataSet(const int64_t rows, const int64_t dim, std::unique_ptr&& tensor) { auto ret_ds = std::make_shared(); diff --git a/include/knowhere/tolower.h b/include/knowhere/tolower.h index 51be51761..87c0dbef3 100644 --- a/include/knowhere/tolower.h +++ b/include/knowhere/tolower.h @@ -19,15 +19,11 @@ namespace knowhere { -static inline std::string str_to_lower(std::string s) { - std::transform( - s.begin(), - s.end(), - s.begin(), - [](unsigned char c){ return std::tolower(c); } - ); +static inline std::string +str_to_lower(std::string s) { + std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); }); return s; } -} +} // namespace knowhere diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 90e095a89..fda8b2abf 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -277,8 +277,9 @@ namespace { // bool -convert_rows_to_fp32(const void* const __restrict src_in, float* const __restrict dst, const DataFormatEnum src_data_format, - const size_t start_row, const size_t nrows, const size_t dim) { +convert_rows_to_fp32(const void* const __restrict src_in, float* const __restrict dst, + const DataFormatEnum src_data_format, const size_t start_row, const size_t nrows, + const size_t dim) { if (src_data_format == DataFormatEnum::fp16) { const knowhere::fp16* const src = reinterpret_cast(src_in); for (size_t i = 0; i < nrows * dim; i++) { @@ -306,14 +307,10 @@ convert_rows_to_fp32(const void* const __restrict src_in, float* const __restric } } -bool convert_rows_from_fp32( - const float* const __restrict src, - void* const __restrict dst_in, - const DataFormatEnum dst_data_format, - const size_t start_row, - const size_t nrows, - const size_t dim -) { +bool +convert_rows_from_fp32(const float* const __restrict src, void* const __restrict dst_in, + const DataFormatEnum dst_data_format, const size_t start_row, const size_t nrows, + const size_t dim) { if (dst_data_format == DataFormatEnum::fp16) { knowhere::fp16* const dst = reinterpret_cast(dst_in); for (size_t i = 0; i < nrows * dim; i++) { @@ -652,8 +649,8 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { return GenResultDataSet(rows, dim, std::move(data)); } else if (data_format == 
DataFormatEnum::fp16) { - auto data = std::make_unique(dim * rows); - + auto data = std::make_unique(dim * rows); + // faiss produces fp32 data format, we need some other format. // Let's create a temporary fp32 buffer for this. auto tmp = std::make_unique(dim); @@ -663,22 +660,15 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { assert(id >= 0 && id < index->ntotal); index->reconstruct(id, tmp.get()); - if (!convert_rows_from_fp32( - tmp.get(), - data.get(), - data_format, - i, - 1, - dim) - ) { + if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) { return expected::Err(Status::invalid_args, "Unsupported data format"); } } return GenResultDataSet(rows, dim, std::move(data)); } else if (data_format == DataFormatEnum::bf16) { - auto data = std::make_unique(dim * rows); - + auto data = std::make_unique(dim * rows); + // faiss produces fp32 data format, we need some other format. // Let's create a temporary fp32 buffer for this. auto tmp = std::make_unique(dim); @@ -688,14 +678,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { assert(id >= 0 && id < index->ntotal); index->reconstruct(id, tmp.get()); - if (!convert_rows_from_fp32( - tmp.get(), - data.get(), - data_format, - i, - 1, - dim) - ) { + if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) { return expected::Err(Status::invalid_args, "Unsupported data format"); } } From 924049ae8b1232a67739632a78d272805b34c837 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 20 Aug 2024 18:47:49 -0400 Subject: [PATCH 21/21] remove a layer of unneeded filters; relying on compiler de-virtualization Signed-off-by: Alexandr Guzhva --- include/knowhere/bitsetview_idselector.h | 4 +- src/index/hnsw/impl/BitsetFilter.h | 35 --------- src/index/hnsw/impl/IndexBruteForceWrapper.cc | 50 ++++++++++--- src/index/hnsw/impl/IndexHNSWWrapper.cc | 48 +++++++------ .../knowhere/IndexBruteForceWrapper.cpp | 72 ++++++++++++------- .../cppcontrib/knowhere/IndexHNSWWrapper.cpp | 67 ++++++++++------- .../cppcontrib/knowhere/impl/Bruteforce.h | 6 +- .../faiss/cppcontrib/knowhere/impl/Filters.h | 44 ------------ .../cppcontrib/knowhere/impl/HnswSearcher.h | 6 +- thirdparty/faiss/faiss/impl/IDSelector.h | 2 +- thirdparty/faiss/faiss/utils/distances.cpp | 2 +- 11 files changed, 164 insertions(+), 172 deletions(-) delete mode 100644 src/index/hnsw/impl/BitsetFilter.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Filters.h diff --git a/include/knowhere/bitsetview_idselector.h b/include/knowhere/bitsetview_idselector.h index 163e58cca..c09aba222 100644 --- a/include/knowhere/bitsetview_idselector.h +++ b/include/knowhere/bitsetview_idselector.h @@ -17,8 +17,8 @@ namespace knowhere { -struct BitsetViewIDSelector : faiss::IDSelector { - BitsetView bitset_view; +struct BitsetViewIDSelector final : faiss::IDSelector { + const BitsetView bitset_view; inline BitsetViewIDSelector(BitsetView bitset_view) : bitset_view{bitset_view} { } diff --git a/src/index/hnsw/impl/BitsetFilter.h b/src/index/hnsw/impl/BitsetFilter.h deleted file mode 100644 index 7b3e5b88b..000000000 --- a/src/index/hnsw/impl/BitsetFilter.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (C) 2019-2024 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. 
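
In the fp16/bf16 branches above, vectors are reconstructed from the faiss index (which always produces fp32) into a temporary row buffer and then converted into the requested storage type via convert_rows_from_fp32(). A condensed sketch of that loop, with HalfT standing in for knowhere::fp16 or knowhere::bf16 and reconstruct_as being an illustrative name:

#include <cassert>
#include <cstdint>
#include <memory>

#include <faiss/Index.h>

// HalfT is assumed to be convertible to/from float.
template <typename HalfT>
std::unique_ptr<HalfT[]>
reconstruct_as(const faiss::Index& index, const int64_t* ids, int64_t rows) {
    const int64_t dim = index.d;
    auto out = std::make_unique<HalfT[]>(rows * dim);

    // faiss reconstructs into fp32; reuse one temporary row buffer.
    auto tmp = std::make_unique<float[]>(dim);
    for (int64_t i = 0; i < rows; i++) {
        const faiss::idx_t id = ids[i];
        assert(id >= 0 && id < index.ntotal);
        index.reconstruct(id, tmp.get());

        for (int64_t j = 0; j < dim; j++) {
            out[i * dim + j] = static_cast<HalfT>(tmp[j]);  // per-element narrowing conversion
        }
    }
    return out;
}
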
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#pragma once - -#include - -#include "knowhere/bitsetview.h" - -namespace knowhere { - -// specialized override for knowhere -struct BitsetFilter { - // contains disabled nodes. - knowhere::BitsetView bitset_view; - - inline BitsetFilter(knowhere::BitsetView bitset_view_) : bitset_view{bitset_view_} { - } - - inline bool - allowed(const faiss::idx_t idx) const { - // there's no check for bitset_view.empty() by design - return !bitset_view.test(idx); - } -}; - -} // namespace knowhere diff --git a/src/index/hnsw/impl/IndexBruteForceWrapper.cc b/src/index/hnsw/impl/IndexBruteForceWrapper.cc index ffd3f74cf..e9084dde7 100644 --- a/src/index/hnsw/impl/IndexBruteForceWrapper.cc +++ b/src/index/hnsw/impl/IndexBruteForceWrapper.cc @@ -21,7 +21,6 @@ #include #include -#include "index/hnsw/impl/BitsetFilter.h" #include "knowhere/bitsetview.h" #include "knowhere/bitsetview_idselector.h" @@ -29,6 +28,21 @@ namespace knowhere { using idx_t = faiss::idx_t; +// the following structure is a hack, because GCC cannot properly +// de-virtualize a plain BitsetViewIDSelector. +struct BitsetViewIDSelectorWrapper final { + const BitsetView bitset_view; + + inline BitsetViewIDSelectorWrapper(BitsetView bitset_view) : bitset_view{bitset_view} { + } + + [[nodiscard]] inline bool + is_member(faiss::idx_t id) const { + // it is by design that bitset_view.empty() is not tested here + return (!bitset_view.test(id)); + } +}; + // IndexBruteForceWrapper::IndexBruteForceWrapper(faiss::Index* underlying_index) : faiss::cppcontrib::knowhere::IndexWrapper{underlying_index} { @@ -36,7 +50,8 @@ IndexBruteForceWrapper::IndexBruteForceWrapper(faiss::Index* underlying_index) void IndexBruteForceWrapper::search(faiss::idx_t n, const float* __restrict x, faiss::idx_t k, float* __restrict distances, - faiss::idx_t* __restrict labels, const faiss::SearchParameters* params) const { + faiss::idx_t* __restrict labels, + const faiss::SearchParameters* __restrict params) const { FAISS_THROW_IF_NOT(k > 0); std::unique_ptr dis(index->get_distance_computer()); @@ -47,8 +62,8 @@ IndexBruteForceWrapper::search(faiss::idx_t n, const float* __restrict x, faiss: dis->set_query(x + i * index->d); // allocate heap - idx_t* const local_ids = labels + i * index->d; - float* const local_distances = distances + i * index->d; + idx_t* const __restrict local_ids = labels + i * index->d; + float* const __restrict local_distances = distances + i * index->d; // set up a filter faiss::IDSelector* sel = (params == nullptr) ? 
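
BitsetViewIDSelectorWrapper above is a workaround: when the scan loop is handed a plain faiss::IDSelector*, GCC does not reliably de-virtualize the per-candidate is_member() call, so the patch passes a concrete (final) filter type as a template parameter and lets the check be inlined. A self-contained sketch of the pattern; VirtualFilter, ConcreteFilter and count_allowed are illustrative, not knowhere types:

#include <cstdint>

// Filtering through a base-class reference forces a virtual call per candidate.
struct VirtualFilter {
    virtual ~VirtualFilter() = default;
    virtual bool is_member(int64_t id) const = 0;
};

// A concrete, non-virtual filter, roughly what BitsetViewIDSelectorWrapper provides.
struct ConcreteFilter final {
    const uint8_t* bits = nullptr;  // stand-in for knowhere::BitsetView
    bool is_member(int64_t id) const {
        return (bits[id >> 3] & (1 << (id & 7))) == 0;  // set bit == filtered out
    }
};

// Because FilterT is a concrete type, the compiler can inline is_member() into
// the scan loop; with a VirtualFilter& it generally cannot.
template <typename FilterT>
int64_t count_allowed(const FilterT& filter, int64_t ntotal) {
    int64_t n = 0;
    for (int64_t id = 0; id < ntotal; id++) {
        if (filter.is_member(id)) {
            n++;
        }
    }
    return n;
}
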
nullptr : params->sel; @@ -59,20 +74,35 @@ IndexBruteForceWrapper::search(faiss::idx_t n, const float* __restrict x, faiss: } // try knowhere-specific filter - const knowhere::BitsetViewIDSelector* bw_idselector = dynamic_cast(sel); + const knowhere::BitsetViewIDSelector* __restrict bw_idselector = + dynamic_cast(sel); - knowhere::BitsetFilter filter(bw_idselector->bitset_view); + BitsetViewIDSelectorWrapper bw_idselector_w(bw_idselector->bitset_view); if (is_similarity_metric(index->metric_type)) { using C = faiss::CMin; - faiss::cppcontrib::knowhere::brute_force_search_impl( - index->ntotal, *dis, filter, k, local_distances, local_ids); + if (bw_idselector == nullptr || bw_idselector->bitset_view.empty()) { + faiss::IDSelectorAll sel_all; + faiss::cppcontrib::knowhere::brute_force_search_impl( + index->ntotal, *dis, sel_all, k, local_distances, local_ids); + } else { + faiss::cppcontrib::knowhere::brute_force_search_impl( + index->ntotal, *dis, bw_idselector_w, k, local_distances, local_ids); + } } else { using C = faiss::CMax; - faiss::cppcontrib::knowhere::brute_force_search_impl( - index->ntotal, *dis, filter, k, local_distances, local_ids); + if (bw_idselector == nullptr || bw_idselector->bitset_view.empty()) { + faiss::IDSelectorAll sel_all; + faiss::cppcontrib::knowhere::brute_force_search_impl( + index->ntotal, *dis, sel_all, k, local_distances, local_ids); + } else { + faiss::cppcontrib::knowhere::brute_force_search_impl( + index->ntotal, *dis, bw_idselector_w, k, local_distances, local_ids); + } } } } diff --git a/src/index/hnsw/impl/IndexHNSWWrapper.cc b/src/index/hnsw/impl/IndexHNSWWrapper.cc index b8d22a31a..bfe101099 100644 --- a/src/index/hnsw/impl/IndexHNSWWrapper.cc +++ b/src/index/hnsw/impl/IndexHNSWWrapper.cc @@ -14,7 +14,6 @@ #include #include #include -#include #include #include #include @@ -27,7 +26,6 @@ #include #include -#include "index/hnsw/impl/BitsetFilter.h" #include "index/hnsw/impl/FederVisitor.h" #include "knowhere/bitsetview.h" #include "knowhere/bitsetview_idselector.h" @@ -83,7 +81,7 @@ IndexHNSWWrapper::IndexHNSWWrapper(faiss::IndexHNSW* underlying_index) void IndexHNSWWrapper::search(idx_t n, const float* __restrict x, idx_t k, float* __restrict distances, - idx_t* __restrict labels, const faiss::SearchParameters* params_in) const { + idx_t* __restrict labels, const faiss::SearchParameters* __restrict params_in) const { FAISS_THROW_IF_NOT(k > 0); const faiss::IndexHNSW* index_hnsw = dynamic_cast(index); @@ -150,61 +148,65 @@ IndexHNSWWrapper::search(idx_t n, const float* __restrict x, idx_t k, float* __r faiss::IDSelector* sel = (params == nullptr) ? nullptr : params->sel; // try knowhere-specific filter - const knowhere::BitsetViewIDSelector* bw_idselector = dynamic_cast(sel); + const knowhere::BitsetViewIDSelector* __restrict bw_idselector = + dynamic_cast(sel); - if (bw_idselector == nullptr) { + if (bw_idselector == nullptr || bw_idselector->bitset_view.empty()) { // no filter - faiss::cppcontrib::knowhere::AllowAllFilter filter; + faiss::IDSelectorAll sel_all; // feder templating is important, bcz it removes an unneeded 'CALL' instruction. 
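
The CMin/CMax split above exists because the result ordering depends on the metric: for similarity metrics (inner product, cosine) the k largest scores are kept, for distance metrics (L2) the k smallest. A plain-std sketch of the same idea, without faiss's CMin/CMax heap machinery (top_k is an illustrative helper, not part of either codebase):

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

using Scored = std::pair<float, int64_t>;  // (score or distance, id)

// Keep the best k candidates; "best" flips direction with the metric type.
inline void
top_k(std::vector<Scored>& candidates, size_t k, bool larger_is_better) {
    k = std::min(k, candidates.size());
    if (larger_is_better) {
        std::partial_sort(candidates.begin(), candidates.begin() + k, candidates.end(),
                          [](const Scored& a, const Scored& b) { return a.first > b.first; });
    } else {
        std::partial_sort(candidates.begin(), candidates.begin() + k, candidates.end(),
                          [](const Scored& a, const Scored& b) { return a.first < b.first; });
    }
    candidates.resize(k);
}
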
if (feder == nullptr) { // no feder DummyVisitor graph_visitor; - using searcher_type = - faiss::cppcontrib::knowhere::v2_hnsw_searcher; + using searcher_type = faiss::cppcontrib::knowhere::v2_hnsw_searcher< + faiss::DistanceComputer, DummyVisitor, faiss::cppcontrib::knowhere::Bitset, faiss::IDSelectorAll>; - searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, filter, kAlpha, params}; + searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, + sel_all, kAlpha, params}; local_stats = searcher.search(k, distances + i * k, labels + i * k); } else { // use feder FederVisitor graph_visitor(feder); - using searcher_type = - faiss::cppcontrib::knowhere::v2_hnsw_searcher; + using searcher_type = faiss::cppcontrib::knowhere::v2_hnsw_searcher< + faiss::DistanceComputer, FederVisitor, faiss::cppcontrib::knowhere::Bitset, faiss::IDSelectorAll>; - searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, filter, kAlpha, params}; + searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, + sel_all, kAlpha, params}; local_stats = searcher.search(k, distances + i * k, labels + i * k); } } else { // with filter - knowhere::BitsetFilter filter(bw_idselector->bitset_view); // feder templating is important, bcz it removes an unneeded 'CALL' instruction. if (feder == nullptr) { // no feder DummyVisitor graph_visitor; - using searcher_type = faiss::cppcontrib::knowhere::v2_hnsw_searcher< - faiss::DistanceComputer, DummyVisitor, faiss::cppcontrib::knowhere::Bitset, knowhere::BitsetFilter>; + using searcher_type = + faiss::cppcontrib::knowhere::v2_hnsw_searcher; - searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, filter, kAlpha, params}; + searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, + *bw_idselector, kAlpha, params}; local_stats = searcher.search(k, distances + i * k, labels + i * k); } else { // use feder FederVisitor graph_visitor(feder); - using searcher_type = faiss::cppcontrib::knowhere::v2_hnsw_searcher< - faiss::DistanceComputer, FederVisitor, faiss::cppcontrib::knowhere::Bitset, knowhere::BitsetFilter>; + using searcher_type = + faiss::cppcontrib::knowhere::v2_hnsw_searcher; - searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, filter, kAlpha, params}; + searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, + *bw_idselector, kAlpha, params}; local_stats = searcher.search(k, distances + i * k, labels + i * k); } diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.cpp b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.cpp index 4c6cad23d..68b6d4d5e 100644 --- a/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.cpp +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.cpp @@ -22,7 +22,6 @@ #include #include -#include namespace faiss { namespace cppcontrib { @@ -37,7 +36,7 @@ void IndexBruteForceWrapper::search( idx_t k, float* __restrict distances, idx_t* __restrict labels, - const SearchParameters* params + const SearchParameters* __restrict params ) const { FAISS_THROW_IF_NOT(k > 0); @@ -57,39 +56,60 @@ void IndexBruteForceWrapper::search( dis->set_query(x + i * index->d); // allocate heap - idx_t* const local_ids = labels + i * index->d; - float* const local_distances = distances + i * index->d; + idx_t* const __restrict local_ids = labels + i * index->d; + float* const __restrict local_distances = distances + i * 
index->d; // set up a filter - IDSelector* sel = (params == nullptr) ? nullptr : params->sel; - - // template, just in case a filter type will be specialized - // in order to remove virtual function call overhead. - using filter_type = DefaultIDSelectorFilter; - filter_type filter(sel); + IDSelector* __restrict sel = (params == nullptr) ? nullptr : params->sel; if (is_similarity_metric(index->metric_type)) { using C = CMin; - brute_force_search_impl( - index->ntotal, - *dis, - filter, - k, - local_distances, - local_ids - ); + if (sel == nullptr) { + // Compiler is expected to de-virtualize virtual method calls + IDSelectorAll sel_all; + brute_force_search_impl( + index->ntotal, + *dis, + sel_all, + k, + local_distances, + local_ids + ); + } else { + brute_force_search_impl( + index->ntotal, + *dis, + *sel, + k, + local_distances, + local_ids + ); + } } else { using C = CMax; - brute_force_search_impl( - index->ntotal, - *dis, - filter, - k, - local_distances, - local_ids - ); + if (sel == nullptr) { + // Compiler is expected to de-virtualize virtual method calls + IDSelectorAll sel_all; + brute_force_search_impl( + index->ntotal, + *dis, + sel_all, + k, + local_distances, + local_ids + ); + } else { + brute_force_search_impl( + index->ntotal, + *dis, + *sel, + k, + local_distances, + local_ids + ); + } } } } diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.cpp b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.cpp index 6d56383e7..dcc20a4c8 100644 --- a/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.cpp +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.cpp @@ -21,7 +21,6 @@ #include #include -#include #include #include @@ -156,29 +155,49 @@ void IndexHNSWWrapper::search( // set up a filter IDSelector* sel = (params == nullptr) ? nullptr : params->sel; - - // template, just in case a filter type will be specialized - // in order to remove virtual function call overhead. - using filter_type = DefaultIDSelectorFilter; - filter_type filter(sel); - - using searcher_type = v2_hnsw_searcher< - DistanceComputer, - DummyVisitor, - Bitset, - filter_type>; - - searcher_type searcher{ - hnsw, - *(dis.get()), - graph_visitor, - bitset_visited_nodes, - filter, - kAlpha, - params - }; - - local_stats = searcher.search(k, distances + i * k, labels + i * k); + if (sel == nullptr) { + // no filter. + // It it expected that a compile will be able to + // de-virtualize the class. 
+ IDSelectorAll sel_all; + + using searcher_type = v2_hnsw_searcher< + DistanceComputer, + DummyVisitor, + Bitset, + IDSelectorAll>; + + searcher_type searcher{ + hnsw, + *(dis.get()), + graph_visitor, + bitset_visited_nodes, + sel_all, + kAlpha, + params + }; + + local_stats = searcher.search(k, distances + i * k, labels + i * k); + } else { + // there is a filter + using searcher_type = v2_hnsw_searcher< + DistanceComputer, + DummyVisitor, + Bitset, + IDSelector>; + + searcher_type searcher{ + hnsw, + *(dis.get()), + graph_visitor, + bitset_visited_nodes, + *sel, + kAlpha, + params + }; + + local_stats = searcher.search(k, distances + i * k, labels + i * k); + } // update stats if possible if (hnsw_stats != nullptr) { diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Bruteforce.h b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Bruteforce.h index f13a35740..10073811a 100644 --- a/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Bruteforce.h +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Bruteforce.h @@ -28,8 +28,8 @@ namespace knowhere { template void brute_force_search_impl( const idx_t ntotal, - DistanceComputerT& qdis, - const FilterT filter, + DistanceComputerT& __restrict qdis, + const FilterT& __restrict filter, const idx_t k, float* __restrict distances, idx_t* __restrict labels @@ -40,7 +40,7 @@ void brute_force_search_impl( auto max_heap = std::make_unique[]>(k); idx_t n_added = 0; for (idx_t idx = 0; idx < ntotal; ++idx) { - if (filter.allowed(idx)) { + if (filter.is_member(idx)) { const float distance = qdis(idx); if (n_added < k) { n_added += 1; diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Filters.h b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Filters.h deleted file mode 100644 index fce3a673b..000000000 --- a/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Filters.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (C) 2019-2024 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#pragma once - -#include - -namespace faiss { -namespace cppcontrib { -namespace knowhere { - -// a filter that selects according to an IDSelector -template -struct DefaultIDSelectorFilter { - // contains enabled nodes. - // a non-owning pointer. - const IDSelectorT* selector = nullptr; - - inline DefaultIDSelectorFilter(const IDSelectorT* selector_) : selector{selector_} {} - - inline bool allowed(const idx_t idx) const { - return ((selector == nullptr) || (selector->is_member(idx))); - } -}; - -// a filter that allows everything -struct AllowAllFilter { - constexpr inline bool allowed(const idx_t) const { - return true; - } -}; - - -} -} -} diff --git a/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/HnswSearcher.h b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/HnswSearcher.h index 29501bfab..3bc624afe 100644 --- a/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/HnswSearcher.h +++ b/thirdparty/faiss/faiss/cppcontrib/knowhere/impl/HnswSearcher.h @@ -72,7 +72,7 @@ struct v2_hnsw_searcher { // a filter for disabled nodes. 
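
brute_force_search_impl (diffed above) is a linear scan: every id the filter accepts is scored with the distance computer and pushed into a bounded heap of size k; after this patch the filter is any type exposing is_member(), so faiss::IDSelectorAll and the knowhere bitset wrapper can both be passed as concrete template arguments. A simplified sketch for the smaller-is-better (distance) case, using std::priority_queue instead of faiss's CMax heap (filtered_brute_force is an illustrative name):

#include <cstdint>
#include <queue>
#include <utility>
#include <vector>

// DistT is any callable mapping id -> distance (like faiss's DistanceComputer);
// FilterT is any type with bool is_member(id).
template <typename DistT, typename FilterT>
std::vector<std::pair<float, int64_t>>
filtered_brute_force(int64_t ntotal, DistT&& dist, const FilterT& filter, size_t k) {
    // max-heap on distance: the worst of the current best-k sits on top.
    std::priority_queue<std::pair<float, int64_t>> heap;
    for (int64_t id = 0; id < ntotal; id++) {
        if (!filter.is_member(id)) {
            continue;  // filtered out, never scored
        }
        const float d = dist(id);
        if (heap.size() < k) {
            heap.emplace(d, id);
        } else if (d < heap.top().first) {
            heap.pop();
            heap.emplace(d, id);
        }
    }
    std::vector<std::pair<float, int64_t>> out(heap.size());
    for (size_t i = out.size(); i-- > 0;) {  // pop yields worst-first; fill back-to-front
        out[i] = heap.top();
        heap.pop();
    }
    return out;
}
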
     // the reference is not owned.
-    const FilterT filter;
+    const FilterT& filter;
 
     // parameter for the filtering
     const float kAlpha;
@@ -205,7 +205,7 @@ struct v2_hnsw_searcher {
         // is the node disabled?
         int status = knowhere::Neighbor::kValid;
-        if (!filter.allowed(v1)) {
+        if (!filter.is_member(v1)) {
             // yes, disabled
             status = knowhere::Neighbor::kInvalid;
 
@@ -381,7 +381,7 @@ struct v2_hnsw_searcher {
 
         // initialize retset with a single 'nearest' point
         {
-            if (!filter.allowed(nearest)) {
+            if (!filter.is_member(nearest)) {
                 retset.insert(knowhere::Neighbor(nearest, d_nearest, knowhere::Neighbor::kInvalid));
             } else {
                 retset.insert(knowhere::Neighbor(nearest, d_nearest, knowhere::Neighbor::kValid));
diff --git a/thirdparty/faiss/faiss/impl/IDSelector.h b/thirdparty/faiss/faiss/impl/IDSelector.h
index 4d02540a2..2966881ea 100644
--- a/thirdparty/faiss/faiss/impl/IDSelector.h
+++ b/thirdparty/faiss/faiss/impl/IDSelector.h
@@ -125,7 +125,7 @@ struct IDSelectorNot : IDSelector {
 
 /// selects all entries (useful for benchmarking)
 struct IDSelectorAll : IDSelector {
-    bool is_member(idx_t id) const final {
+    inline bool is_member(idx_t id) const final {
         return true;
     }
     virtual ~IDSelectorAll() {}
diff --git a/thirdparty/faiss/faiss/utils/distances.cpp b/thirdparty/faiss/faiss/utils/distances.cpp
index 6f4186443..1584d72e1 100644
--- a/thirdparty/faiss/faiss/utils/distances.cpp
+++ b/thirdparty/faiss/faiss/utils/distances.cpp
@@ -135,7 +135,7 @@ namespace {
 
 // may be useful if the lion's share of samples are filtered out.
 struct IDSelectorAll {
-    inline bool is_member(const size_t idx) const {
+    constexpr inline bool is_member(const size_t idx) const {
         return true;
     }
 };
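
Note how the HNSW searcher handles filtered entries in the hunks above: a node rejected by the filter is not dropped, it is inserted with status kInvalid so the graph traversal can still expand through it, while only kValid nodes are meant to reach the final results. A small sketch of that status-tagging idea (Candidate and make_candidate are illustrative; knowhere uses its own Neighbor struct with kValid/kInvalid):

#include <cstdint>

struct Candidate {
    int64_t id;
    float distance;
    bool valid;  // false => filtered out: usable for traversal, never returned
};

template <typename FilterT>
Candidate make_candidate(int64_t id, float distance, const FilterT& filter) {
    // A rejected node is kept as an expansion point instead of being discarded,
    // which helps keep the graph walk connected under heavy filtering.
    return Candidate{id, distance, filter.is_member(id)};
}
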