From edc1cbdfca7d8f32ea92a508b89df95813389f13 Mon Sep 17 00:00:00 2001
From: Alexandr Guzhva
Date: Thu, 1 Aug 2024 18:39:35 -0400
Subject: [PATCH] FAISS_HNSW_PQ

Signed-off-by: Alexandr Guzhva
---
 include/knowhere/comp/index_param.h         |   4 +
 src/index/hnsw/faiss_hnsw.cc                | 365 ++++++++++++++++----
 src/index/hnsw/faiss_hnsw_config.h          |  21 ++
 thirdparty/faiss/faiss/IndexCosine.cpp      |  61 ++++
 thirdparty/faiss/faiss/IndexCosine.h        |  34 ++
 thirdparty/faiss/faiss/impl/index_read.cpp  |  21 +-
 thirdparty/faiss/faiss/impl/index_write.cpp |  13 +
 7 files changed, 443 insertions(+), 76 deletions(-)

diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h
index 2f71ac84a..c89a0ab5b 100644
--- a/include/knowhere/comp/index_param.h
+++ b/include/knowhere/comp/index_param.h
@@ -55,6 +55,7 @@ constexpr const char* INDEX_DISKANN = "DISKANN";
 
 constexpr const char* INDEX_FAISS_HNSW_FLAT = "FAISS_HNSW_FLAT";
 constexpr const char* INDEX_FAISS_HNSW_SQ = "FAISS_HNSW_SQ";
+constexpr const char* INDEX_FAISS_HNSW_PQ = "FAISS_HNSW_PQ";
 
 constexpr const char* INDEX_SPARSE_INVERTED_INDEX = "SPARSE_INVERTED_INDEX";
 constexpr const char* INDEX_SPARSE_WAND = "SPARSE_WAND";
@@ -153,6 +154,9 @@ constexpr const char* HNSW_M = "M";
 constexpr const char* EF = "ef";
 constexpr const char* OVERVIEW_LEVELS = "overview_levels";
 
+// FAISS additional Params
+constexpr const char* SQ_TYPE = "sq_type";  // for IVF_SQ and HNSW_SQ
+
 // Sparse Params
 constexpr const char* DROP_RATIO_BUILD = "drop_ratio_build";
 constexpr const char* DROP_RATIO_SEARCH = "drop_ratio_search";
diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc
index 07fe85c9f..cb8afcc5b 100644
--- a/src/index/hnsw/faiss_hnsw.cc
+++ b/src/index/hnsw/faiss_hnsw.cc
@@ -539,7 +539,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
 
     Status
     AddInternal(const DataSetPtr dataset, const Config&) override {
-        if (this->index == nullptr) {
+        if (index == nullptr) {
             LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index.";
             return Status::empty_index;
         }
@@ -547,7 +547,9 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
         auto data = dataset->GetTensor();
         auto rows = dataset->GetRows();
         try {
-            this->index->add(rows, reinterpret_cast<const float*>(data));
+            LOG_KNOWHERE_INFO_ << "Adding " << rows << " rows to HNSW Index";
+
+            index->add(rows, reinterpret_cast<const float*>(data));
         } catch (const std::exception& e) {
             LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
             return Status::faiss_inner_error;
@@ -617,6 +619,8 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode {
         hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value();
 
         // train
+        LOG_KNOWHERE_INFO_ << "Training HNSW Index";
+
         hnsw_index->train(rows, (const float*)data);
 
         // done
@@ -625,6 +629,112 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode {
     }
 };
 
+
+namespace {
+
+// a supporting function
+expected<faiss::ScalarQuantizer::QuantizerType>
+get_sq_quantizer_type(const std::string& sq_type) {
+    std::map<std::string, faiss::ScalarQuantizer::QuantizerType> sq_types = {
+        {"SQ6", faiss::ScalarQuantizer::QT_6bit},
+        {"SQ8", faiss::ScalarQuantizer::QT_8bit},
+        {"FP16", faiss::ScalarQuantizer::QT_fp16},
+        {"BF16", faiss::ScalarQuantizer::QT_bf16}
+    };
+
+    // todo: tolower
+    auto itr = sq_types.find(sq_type);
+    if (itr == sq_types.cend()) {
+        return expected<faiss::ScalarQuantizer::QuantizerType>::Err(
+            Status::invalid_args, fmt::format("invalid scalar quantizer type ({})", sq_type));
+    }
+
+    return itr->second;
+}
+
+expected<bool>
+is_flat_refine(const std::optional<std::string>& refine_type) {
+    // grab a type of a refine index
+    if (!refine_type.has_value()) {
+        return true;
+    };
+
+    // todo: tolower
+    if (refine_type.value() == "FP32" || refine_type.value() == "FLAT") {
+        return true;
+    };
+
+    // parse
+    auto refine_sq_type = get_sq_quantizer_type(refine_type.value());
+    if (!refine_sq_type.has_value()) {
+        LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value();
+        return expected<bool>::Err(
+            Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value()));
+    }
+
+    return false;
+}
+
+// pick a refine index
+expected<std::unique_ptr<faiss::Index>>
+pick_refine_index(const std::optional<std::string>& refine_type, std::unique_ptr<faiss::IndexHNSW>&& hnsw_index) {
+    // yes
+
+    // grab a type of a refine index
+    expected<bool> is_fp32_flat = is_flat_refine(refine_type);
+    if (!is_fp32_flat.has_value()) {
+        return expected<std::unique_ptr<faiss::Index>>::Err(
+            Status::invalid_args, "");
+    }
+
+    const bool is_fp32_flat_v = is_fp32_flat.value();
+
+    std::unique_ptr<faiss::IndexHNSW> local_hnsw_index = std::move(hnsw_index);
+
+    // either build flat or sq
+    if (is_fp32_flat_v) {
+        // build IndexFlat as a refine
+        auto refine_index = std::make_unique<faiss::IndexRefineFlat>(local_hnsw_index.get());
+
+        // let refine_index own everything
+        refine_index->own_fields = true;
+        local_hnsw_index.release();
+
+        // reassign
+        return refine_index;
+    } else {
+        // build IndexScalarQuantizer as a refine
+        auto refine_sq_type = get_sq_quantizer_type(refine_type.value());
+
+        // a redundant check
+        if (!refine_sq_type.has_value()) {
+            LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value();
+            return expected<std::unique_ptr<faiss::Index>>::Err(
+                Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value()));
+        }
+
+        // create an sq
+        auto sq_refine = std::make_unique<faiss::IndexScalarQuantizer>(
+            local_hnsw_index->storage->d,
+            refine_sq_type.value(),
+            local_hnsw_index->storage->metric_type
+        );
+
+        auto refine_index = std::make_unique<faiss::IndexRefine>(local_hnsw_index.get(), sq_refine.get());
+
+        // let refine_index own everything
+        refine_index->own_refine_index = true;
+        refine_index->own_fields = true;
+        local_hnsw_index.release();
+        sq_refine.release();
+
+        // reassign
+        return refine_index;
+    }
+}
+
+}
+
 //
 template <typename DataType>
 class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode {
@@ -685,103 +795,208 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode {
         std::unique_ptr<faiss::Index> final_index;
         if (hnsw_cfg.refine.value_or(false)) {
             // yes
+            auto final_index_cnd = pick_refine_index(hnsw_cfg.refine_type, std::move(hnsw_index));
+            if (!final_index_cnd.has_value()) {
+                return Status::invalid_args;
+            }
 
-            // grab a type of a refine index
-            bool is_fp32_flat = false;
-            if (!hnsw_cfg.refine_type.has_value()) {
-                is_fp32_flat = true;
-            } else {
-                // todo: tolower
-                if (hnsw_cfg.refine_type.value() == "FP32" || hnsw_cfg.refine_type.value() == "FLAT") {
-                    is_fp32_flat = true;
-                } else {
-                    // parse
-                    auto refine_sq_type = get_sq_quantizer_type(hnsw_cfg.refine_type.value());
-                    if (!refine_sq_type.has_value()) {
-                        LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << hnsw_cfg.refine_type.value();
-                        return Status::invalid_args;
-                    }
+            // assign
+            final_index = std::move(final_index_cnd.value());
+        } else {
+            // no refine
 
-                    // recognized one, it is not fp32 flat
-                    is_fp32_flat = false;
-                }
-            }
+            // assign
+            final_index = std::move(hnsw_index);
+        }
 
-            // either build flat or sq
-            if (is_fp32_flat) {
-                // build IndexFlat as a refine
-                auto refine_index = std::make_unique<faiss::IndexRefineFlat>(hnsw_index.get());
-
-                // let refine_index to own everything
-                refine_index->own_fields = true;
-                hnsw_index.release();
-
-                // reassign
-                final_index = std::move(refine_index);
-            } else {
-                // being IndexScalarQuantizer as a refine
-                auto refine_sq_type = get_sq_quantizer_type(hnsw_cfg.refine_type.value());
-
-                // a redundant check
-                if (!refine_sq_type.has_value()) {
-                    LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << hnsw_cfg.refine_type.value();
-                    return Status::invalid_args;
-                }
-
-                // create an sq
-                auto sq_refine = std::make_unique<faiss::IndexScalarQuantizer>(
-                    dim, refine_sq_type.value(), metric.value()
-                );
-
-                auto refine_index = std::make_unique<faiss::IndexRefine>(hnsw_index.get(), sq_refine.get());
-
-                // let refine_index to own everything
-                refine_index->own_refine_index = true;
-                refine_index->own_fields = true;
-                hnsw_index.release();
-                sq_refine.release();
-
-                // reassign
-                final_index = std::move(refine_index);
-            }
+        // train
+        LOG_KNOWHERE_INFO_ << "Training HNSW Index";
+
+        final_index->train(rows, (const float*)data);
+
+        // done
+        index = std::move(final_index);
+        return Status::success;
+    }
+};
+
+
+// this index trains PQ and HNSW+FLAT separately, then constructs HNSW+PQ
+template <typename DataType>
+class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode {
+ public:
+    BaseFaissRegularIndexHNSWPQNode(const int32_t& version, const Object& object) :
+        BaseFaissRegularIndexHNSWNode(version, object) {}
+
+    std::unique_ptr<BaseConfig>
+    CreateConfig() const override {
+        return std::make_unique<FaissHnswPqConfig>();
+    }
+
+    std::string
+    Type() const override {
+        return knowhere::IndexEnum::INDEX_FAISS_HNSW_PQ;
+    }
+
+ protected:
+    std::unique_ptr<faiss::IndexPQ> tmp_index_pq;
+
+    Status
+    TrainInternal(const DataSetPtr dataset, const Config& cfg) override {
+        // number of rows
+        auto rows = dataset->GetRows();
+        // dimensionality of the data
+        auto dim = dataset->GetDim();
+        // data
+        auto data = dataset->GetTensor();
+
+        // config
+        auto hnsw_cfg = static_cast<const FaissHnswPqConfig&>(cfg);
+
+        auto metric = Str2FaissMetricType(hnsw_cfg.metric_type.value());
+        if (!metric.has_value()) {
+            LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << hnsw_cfg.metric_type.value();
+            return Status::invalid_metric_type;
+        }
+
+        // create an index
+        const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE);
+
+        // HNSW + PQ index yields BAD recall for some reason.
+        // Let's build HNSW+FLAT index, then replace FLAT with PQ
+
+        std::unique_ptr<faiss::IndexHNSW> hnsw_index;
+        if (is_cosine) {
+            hnsw_index = std::make_unique<faiss::IndexHNSWFlatCosine>(dim, hnsw_cfg.M.value());
+        } else {
+            hnsw_index = std::make_unique<faiss::IndexHNSWFlat>(dim, hnsw_cfg.M.value(), metric.value());
+        }
+
+        hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value();
+
+        // pq
+        std::unique_ptr<faiss::IndexPQ> pq_index;
+        if (is_cosine) {
+            pq_index = std::make_unique<faiss::IndexPQCosine>(dim, hnsw_cfg.m.value(), hnsw_cfg.nbits.value());
+        } else {
+            pq_index = std::make_unique<faiss::IndexPQ>(dim, hnsw_cfg.m.value(), hnsw_cfg.nbits.value(), metric.value());
+        }
+
+
+        // should refine be used?
+        std::unique_ptr<faiss::Index> final_index;
+        if (hnsw_cfg.refine.value_or(false)) {
+            // yes
+            auto final_index_cnd = pick_refine_index(hnsw_cfg.refine_type, std::move(hnsw_index));
+            if (!final_index_cnd.has_value()) {
+                return Status::invalid_args;
+            }
+
+            // assign
+            final_index = std::move(final_index_cnd.value());
         } else {
             // no refine
 
-            // reassign
+            // assign
             final_index = std::move(hnsw_index);
         }
 
-        // train
+        // train hnswflat
+        LOG_KNOWHERE_INFO_ << "Training HNSW Index";
+
         final_index->train(rows, (const float*)data);
+
+        // train pq
+        LOG_KNOWHERE_INFO_ << "Training PQ Index";
+
+        pq_index->train(rows, (const float*)data);
+        pq_index->pq.compute_sdc_table();
 
         // done
         index = std::move(final_index);
+        tmp_index_pq = std::move(pq_index);
 
         return Status::success;
     }
 
-private:
-    expected<faiss::ScalarQuantizer::QuantizerType>
-    static get_sq_quantizer_type(const std::string& sq_type) {
-        std::map<std::string, faiss::ScalarQuantizer::QuantizerType> sq_types = {
-            {"SQ6", faiss::ScalarQuantizer::QT_6bit},
-            {"SQ8", faiss::ScalarQuantizer::QT_8bit},
-            {"FP16", faiss::ScalarQuantizer::QT_fp16},
-            {"BF16", faiss::ScalarQuantizer::QT_bf16}
-        };
+    Status
+    AddInternal(const DataSetPtr dataset, const Config&) override {
+        if (this->index == nullptr) {
+            LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index.";
+            return Status::empty_index;
+        }
+
+        auto data = dataset->GetTensor();
+        auto rows = dataset->GetRows();
+        try {
+            // hnsw
+            LOG_KNOWHERE_INFO_ << "Adding " << rows << " rows to HNSW Index";
+
+            index->add(rows, reinterpret_cast<const float*>(data));
+
+            // pq
+            LOG_KNOWHERE_INFO_ << "Adding " << rows << " rows to PQ Index";
+
+            tmp_index_pq->add(rows, reinterpret_cast<const float*>(data));
 
-        // todo: tolower
-        auto itr = sq_types.find(sq_type);
-        if (itr == sq_types.cend()) {
-            return expected<faiss::ScalarQuantizer::QuantizerType>::Err(
-                Status::invalid_args, fmt::format("invalid scalar quantizer type ({})", sq_type));
+            // we're done.
+            // throw away flat and replace it with pq
+
+            // check if we have a refine available.
+            faiss::IndexHNSW* index_hnsw = nullptr;
+
+            faiss::IndexRefine* const index_refine =
+                dynamic_cast<faiss::IndexRefine*>(index.get());
+
+            if (index_refine != nullptr) {
+                index_hnsw = dynamic_cast<faiss::IndexHNSW*>(index_refine->base_index);
+            } else {
+                index_hnsw = dynamic_cast<faiss::IndexHNSW*>(index.get());
+            }
+
+            // recreate hnswpq
+            std::unique_ptr<faiss::IndexHNSW> index_hnsw_pq;
+
+            if (index_hnsw->storage->is_cosine) {
+                index_hnsw_pq = std::make_unique<faiss::IndexHNSWPQCosine>();
+            } else {
+                index_hnsw_pq = std::make_unique<faiss::IndexHNSWPQ>();
+            }
+
+            // C++ slicing: move the IndexHNSW base part (the graph) onto the new index
+            static_cast<faiss::IndexHNSW&>(*index_hnsw_pq) =
+                std::move(static_cast<faiss::IndexHNSW&>(*index_hnsw));
+
+            // clear out the storage
+            delete index_hnsw->storage;
+            index_hnsw->storage = nullptr;
+            index_hnsw_pq->storage = nullptr;
+
+            // replace storage
+            index_hnsw_pq->storage = tmp_index_pq.release();
+
+            // replace if refine
+            if (index_refine != nullptr) {
+                delete index_refine->base_index;
+                index_refine->base_index = index_hnsw_pq.release();
+            } else {
+                index = std::move(index_hnsw_pq);
+            }
+
+        } catch (const std::exception& e) {
+            LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
+            return Status::faiss_inner_error;
         }
 
-        return itr->second;
+        return Status::success;
     }
+
 };
 
 KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNode, fp32);
 KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_SQ, BaseFaissRegularIndexHNSWSQNode, fp32);
+KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_PQ, BaseFaissRegularIndexHNSWPQNode, fp32);
 
 }  // namespace knowhere
diff --git a/src/index/hnsw/faiss_hnsw_config.h b/src/index/hnsw/faiss_hnsw_config.h
index 637e6280d..c1fcb76f6 100644
--- a/src/index/hnsw/faiss_hnsw_config.h
+++ b/src/index/hnsw/faiss_hnsw_config.h
@@ -220,6 +220,27 @@ class FaissHnswSqConfig : public FaissHnswConfig {
     }
 };
 
+class FaissHnswPqConfig : public FaissHnswConfig {
+ public:
+    // number of subquantizers
+    CFG_INT m;
+    // number of bits per subquantizer
+    CFG_INT nbits;
+
+    KNOHWERE_DECLARE_CONFIG(FaissHnswPqConfig) {
+        KNOWHERE_CONFIG_DECLARE_FIELD(m)
+            .description("m")
+            .set_default(32)
+            .for_train()
+            .set_range(1, 65536);
+        KNOWHERE_CONFIG_DECLARE_FIELD(nbits)
+            .description("nbits")
+            .set_default(8)
+            .for_train()
+            .set_range(1, 16);
+    }
+};
+
 }  // namespace knowhere
 
 #endif /* FAISS_HNSW_CONFIG_H */
diff --git a/thirdparty/faiss/faiss/IndexCosine.cpp b/thirdparty/faiss/faiss/IndexCosine.cpp
index 7454b4974..fc12c61bb 100644
--- a/thirdparty/faiss/faiss/IndexCosine.cpp
+++ b/thirdparty/faiss/faiss/IndexCosine.cpp
@@ -346,6 +346,51 @@ DistanceComputer* IndexScalarQuantizerCosine::get_distance_computer() const {
 }
 
 
+//////////////////////////////////////////////////////////////////////////////////
+
+//
+IndexPQCosine::IndexPQCosine(int d, size_t M, size_t nbits) :
+    IndexPQ(d, M, nbits, MetricType::METRIC_INNER_PRODUCT) {
+    is_cosine = true;
+}
+
+IndexPQCosine::IndexPQCosine() : IndexPQ() {
+    metric_type = MetricType::METRIC_INNER_PRODUCT;
+    is_cosine = true;
+}
+
+void IndexPQCosine::add(idx_t n, const float* x) {
+    FAISS_THROW_IF_NOT(is_trained);
+    if (n == 0) {
+        return;
+    }
+
+    // todo aguzhva:
+    // it is a tricky situation at this moment, because IndexPQCosine
+    // contains duplicate norms (one in IndexFlatCodes and one in HasInverseL2Norms).
+    // Norms in IndexFlatCodes are going to be removed in the future.
+    IndexPQ::add(n, x);
+    inverse_norms_storage.add(x, n, d);
+}
+
+void IndexPQCosine::reset() {
+    IndexPQ::reset();
+    inverse_norms_storage.reset();
+}
+
+const float* IndexPQCosine::get_inverse_l2_norms() const {
+    return inverse_norms_storage.inverse_l2_norms.data();
+}
+
+DistanceComputer* IndexPQCosine::get_distance_computer() const {
+    return new WithCosineNormDistanceComputer(
+        this->get_inverse_l2_norms(),
+        this->d,
+        std::unique_ptr<FlatCodesDistanceComputer>(IndexPQ::get_FlatCodesDistanceComputer())
+    );
+}
+
+
 //////////////////////////////////////////////////////////////////////////////////
 
 //
@@ -360,6 +405,7 @@ IndexHNSWFlatCosine::IndexHNSWFlatCosine(int d, int M) :
     is_trained = true;
 }
 
+
 //////////////////////////////////////////////////////////////////////////////////
 
 //
@@ -376,5 +422,20 @@ IndexHNSWSQCosine::IndexHNSWSQCosine(
 }
 
 
+//
+IndexHNSWPQCosine::IndexHNSWPQCosine() = default;
+
+IndexHNSWPQCosine::IndexHNSWPQCosine(int d, size_t pq_M, int M, size_t pq_nbits) :
+    IndexHNSW(new IndexPQCosine(d, pq_M, pq_nbits), M)
+{
+    own_fields = true;
+}
+
+void IndexHNSWPQCosine::train(idx_t n, const float* x) {
+    IndexHNSW::train(n, x);
+    (dynamic_cast<IndexPQ*>(storage))->pq.compute_sdc_table();
+}
+
+
 }
diff --git a/thirdparty/faiss/faiss/IndexCosine.h b/thirdparty/faiss/faiss/IndexCosine.h
index cd283f795..3bb3eae7a 100644
--- a/thirdparty/faiss/faiss/IndexCosine.h
+++ b/thirdparty/faiss/faiss/IndexCosine.h
@@ -13,8 +13,13 @@
 
 #pragma once
 
+#include
+#include
+
 #include
 #include
+#include
+#include
 
 #include
 
@@ -117,6 +122,23 @@ struct IndexScalarQuantizerCosine : IndexScalarQuantizer, HasInverseL2Norms {
     const float* get_inverse_l2_norms() const override;
 };
 
+//
+struct IndexPQCosine : IndexPQ, HasInverseL2Norms {
+    L2NormsStorage inverse_norms_storage;
+
+    IndexPQCosine(int d, size_t M, size_t nbits);
+
+    IndexPQCosine();
+
+    void add(idx_t n, const float* x) override;
+    void reset() override;
+
+    DistanceComputer* get_distance_computer() const override;
+
+    const float* get_inverse_l2_norms() const override;
+};
+
+
 //
 struct IndexHNSWFlatCosine : IndexHNSW {
     IndexHNSWFlatCosine();
@@ -132,5 +154,17 @@ struct IndexHNSWSQCosine : IndexHNSW {
         int M);
 };
 
+//
+struct IndexHNSWPQCosine : IndexHNSW {
+    IndexHNSWPQCosine();
+    IndexHNSWPQCosine(
+        int d,
+        size_t pq_M,
+        int M,
+        size_t pq_nbits);
+
+    void train(idx_t n, const float* x) override;
+};
+
 }
diff --git a/thirdparty/faiss/faiss/impl/index_read.cpp b/thirdparty/faiss/faiss/impl/index_read.cpp
index b980f2859..9889342f5 100644
--- a/thirdparty/faiss/faiss/impl/index_read.cpp
+++ b/thirdparty/faiss/faiss/impl/index_read.cpp
@@ -752,6 +752,23 @@ Index* read_index(IOReader* f, int io_flags) {
         FAISS_THROW_IF_NOT(
                 idxl->codes.size() == idxl->ntotal * idxl->code_size);
         idx = idxl;
+    } else if (h == fourcc("IxP7")) {
+        IndexPQCosine* idxp = new IndexPQCosine();
+        read_index_header(idxp, f);
+        read_ProductQuantizer(&idxp->pq, f);
+        idxp->code_size = idxp->pq.code_size;
+        READVECTOR(idxp->codes);
+        READ1(idxp->search_type);
+        READ1(idxp->encode_signs);
+        READ1(idxp->polysemous_ht);
+        // read inverse norms
+        READVECTOR(idxp->inverse_norms_storage.inverse_l2_norms);
+
+        if (!(io_flags & IO_FLAG_PQ_SKIP_SDC_TABLE)) {
+            idxp->pq.compute_sdc_table();
+        }
+
+        idx = idxp;
     } else if (
             h == fourcc("IxPQ") || h == fourcc("IxPo") || h == fourcc("IxPq")) {
         // IxPQ and IxPo were merged into the same IndexPQ object
@@ -1166,7 +1183,7 @@ Index* read_index(IOReader* f, int io_flags) {
     } else if (
             h == fourcc("IHNf") || h == fourcc("IHNp") || h == fourcc("IHNs") ||
             h == fourcc("IHN2") || h == fourcc("IHNc") || h == fourcc("IHN9") ||
-            h == fourcc("IHN8")) {
+            h == fourcc("IHN8") || h == fourcc("IHN7")) {
         IndexHNSW* idxhnsw = nullptr;
         if (h == fourcc("IHNf"))
             idxhnsw = new IndexHNSWFlat();
@@ -1182,6 +1199,8 @@ Index* read_index(IOReader* f, int io_flags) {
             idxhnsw = new IndexHNSWFlatCosine();
         if (h == fourcc("IHN8"))
             idxhnsw = new IndexHNSWSQCosine();
+        if (h == fourcc("IHN7"))
+            idxhnsw = new IndexHNSWPQCosine();
         read_index_header(idxhnsw, f);
         if (h == fourcc("IHNc")) {
             READ1(idxhnsw->keep_max_size_level0);
diff --git a/thirdparty/faiss/faiss/impl/index_write.cpp b/thirdparty/faiss/faiss/impl/index_write.cpp
index b465e4bce..d44a223ea 100644
--- a/thirdparty/faiss/faiss/impl/index_write.cpp
+++ b/thirdparty/faiss/faiss/impl/index_write.cpp
@@ -591,6 +591,18 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) {
         WRITE1(code_size_i);
         write_VectorTransform(&idxl->rrot, f);
         WRITEVECTOR(idxl->codes);
+    } else if (const IndexPQCosine* idxp = dynamic_cast<const IndexPQCosine*>(idx)) {
+        uint32_t h = fourcc("IxP7");
+        WRITE1(h);
+        write_index_header(idx, f);
+        write_ProductQuantizer(&idxp->pq, f);
+        WRITEVECTOR(idxp->codes);
+        // search params -- maybe not useful to store?
+        WRITE1(idxp->search_type);
+        WRITE1(idxp->encode_signs);
+        WRITE1(idxp->polysemous_ht);
+        // inverse norms
+        WRITEVECTOR(idxp->inverse_norms_storage.inverse_l2_norms);
     } else if (const IndexPQ* idxp = dynamic_cast<const IndexPQ*>(idx)) {
         uint32_t h = fourcc("IxPq");
         WRITE1(h);
@@ -969,6 +981,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) {
                 : dynamic_cast<const IndexHNSWCagra*>(idx)      ? fourcc("IHNc")
                 : dynamic_cast<const IndexHNSWFlatCosine*>(idx) ? fourcc("IHN9")
                 : dynamic_cast<const IndexHNSWSQCosine*>(idx)   ? fourcc("IHN8")
+                : dynamic_cast<const IndexHNSWPQCosine*>(idx)   ? fourcc("IHN7")
                 : 0;
         FAISS_THROW_IF_NOT(h != 0);
         WRITE1(h);
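
Usage sketch (not part of the diff): the snippet below shows how the new FAISS_HNSW_PQ index might be exercised through knowhere's generic factory API, combining the PQ parameters introduced by FaissHnswPqConfig (m, nbits) with the existing HNSW and refine knobs. The factory, dataset, and config calls reflect the knowhere 2.x API as assumed here, and the data is placeholder; treat the exact names and signatures as assumptions rather than test code belonging to this patch.

    // Hypothetical build/search of FAISS_HNSW_PQ via knowhere (assumed 2.x API).
    #include <cstdint>
    #include <memory>
    #include <vector>
    #include "knowhere/comp/index_param.h"
    #include "knowhere/dataset.h"
    #include "knowhere/index/index_factory.h"
    #include "knowhere/version.h"

    int main() {
        const int64_t rows = 10000, dim = 128;
        std::vector<float> base(rows * dim, 0.5f);  // fill with real vectors in practice

        auto ds = std::make_shared<knowhere::DataSet>();
        ds->SetRows(rows);
        ds->SetDim(dim);
        ds->SetTensor(base.data());

        auto version = knowhere::Version::GetCurrentVersion().VersionNumber();
        auto idx = knowhere::IndexFactory::Instance()
                       .Create<knowhere::fp32>("FAISS_HNSW_PQ", version)
                       .value();

        knowhere::Json cfg;
        cfg["metric_type"] = "L2";
        cfg["M"] = 32;                 // HNSW graph degree
        cfg["efConstruction"] = 100;
        cfg["m"] = 16;                 // PQ subquantizers (FaissHnswPqConfig::m)
        cfg["nbits"] = 8;              // bits per subquantizer
        cfg["refine"] = true;          // optional re-ranking stage
        cfg["refine_type"] = "SQ8";

        if (idx.Build(ds, cfg) != knowhere::Status::success) {
            return 1;
        }

        cfg["k"] = 10;
        cfg["ef"] = 64;
        auto result = idx.Search(ds, cfg, nullptr);  // reuse base rows as queries here
        return result.has_value() ? 0 : 1;
    }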
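
The core trick of BaseFaissRegularIndexHNSWPQNode is performed in AddInternal: the graph is built over exact (flat) vectors, a separate IndexPQ is trained and filled with the same vectors in the same order, and then the flat storage of the HNSW index is swapped for the PQ index. The standalone sketch below illustrates that swap with plain faiss types on the non-cosine path; it mirrors the patch's logic but is not taken from it, and the sizes and fill values are placeholders.

    // Sketch of the FLAT -> PQ storage swap (plain faiss, non-cosine path).
    #include <faiss/IndexHNSW.h>
    #include <faiss/IndexPQ.h>
    #include <memory>
    #include <vector>

    int main() {
        const int d = 64, n = 20000, M = 32, pq_m = 8, pq_nbits = 8;
        std::vector<float> data(size_t(n) * d, 0.1f);  // use real vectors in practice

        // 1. build the HNSW graph over exact (flat) storage
        auto hnsw = std::make_unique<faiss::IndexHNSWFlat>(d, M);
        hnsw->hnsw.efConstruction = 100;
        hnsw->add(n, data.data());

        // 2. train a PQ codec on the same vectors, added in the same order
        auto pq = std::make_unique<faiss::IndexPQ>(d, pq_m, pq_nbits);
        pq->train(n, data.data());
        pq->add(n, data.data());

        // 3. copy the IndexHNSW base part (graph, ntotal, metric) onto an IndexHNSWPQ,
        //    then replace the shared flat storage with the trained PQ index
        auto hnsw_pq = std::make_unique<faiss::IndexHNSWPQ>();
        static_cast<faiss::IndexHNSW&>(*hnsw_pq) = static_cast<faiss::IndexHNSW&>(*hnsw);
        delete hnsw_pq->storage;          // drop the flat vectors that came with the copy
        hnsw->storage = nullptr;          // the original index must not delete them again
        hnsw_pq->storage = pq.release();  // graph ids now resolve against PQ codes

        // the resulting index traverses the same graph with PQ-approximated distances
        std::vector<faiss::idx_t> labels(10);
        std::vector<float> distances(10);
        hnsw_pq->hnsw.efSearch = 64;
        hnsw_pq->search(1, data.data(), 10, distances.data(), labels.data());
        return 0;
    }

Because graph ids are just positions in the storage, adding the identical vectors to the PQ index in the identical order keeps the graph and the compressed codes consistent after the swap.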
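
The refine option shared by FAISS_HNSW_SQ and FAISS_HNSW_PQ wraps the HNSW index in a faiss::IndexRefine, so that candidates found in the compressed graph are re-ranked against a second, higher-precision copy of the vectors. Below is a small sketch of the wiring that pick_refine_index() produces for refine_type = "SQ8", written against plain faiss rather than copied from the knowhere code; the function name is illustrative.

    // Sketch of the refine wiring for refine_type = "SQ8" (plain faiss).
    #include <faiss/IndexHNSW.h>
    #include <faiss/IndexRefine.h>
    #include <faiss/IndexScalarQuantizer.h>

    faiss::Index* make_refined_hnsw(int d, int M) {
        // base index: HNSW graph over flat storage, searched first
        auto* hnsw = new faiss::IndexHNSWFlat(d, M);

        // refiner: SQ8 codes of the original vectors, used to recompute
        // distances for the candidates returned by the HNSW search
        auto* sq = new faiss::IndexScalarQuantizer(
                d, faiss::ScalarQuantizer::QT_8bit, hnsw->metric_type);

        auto* refine = new faiss::IndexRefine(hnsw, sq);
        refine->own_fields = true;        // refine owns the base HNSW index
        refine->own_refine_index = true;  // ...and the SQ refiner
        return refine;                    // train/add/search go through this wrapper
    }

With refine_type = "FP32" or "FLAT" (or unset), the patch uses faiss::IndexRefineFlat instead, i.e. exact float vectors as the re-ranking storage.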