Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
add ivf_flat_cc index (#824)
Browse files Browse the repository at this point in the history
Signed-off-by: xianliang <xianliang.li@zilliz.com>
  • Loading branch information
foxspy authored Apr 21, 2023
1 parent 5cd495a commit 717a7b0
Show file tree
Hide file tree
Showing 19 changed files with 993 additions and 43 deletions.
4 changes: 4 additions & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ constexpr const char* INDEX_FAISS_BIN_IVFFLAT = "BIN_IVF_FLAT";

constexpr const char* INDEX_FAISS_IDMAP = "FLAT";
constexpr const char* INDEX_FAISS_IVFFLAT = "IVF_FLAT";
constexpr const char* INDEX_FAISS_IVFFLAT_CC = "IVF_FLAT_CC";
constexpr const char* INDEX_FAISS_IVFPQ = "IVF_PQ";
constexpr const char* INDEX_FAISS_IVFSQ8 = "IVF_SQ8";

Expand All @@ -43,6 +44,7 @@ constexpr const char* INDEX_DISKANN = "DISKANN";
} // namespace IndexEnum

namespace meta {
constexpr const char* INDEX_TYPE = "index_type";
constexpr const char* METRIC_TYPE = "metric_type";
constexpr const char* DIM = "dim";
constexpr const char* TENSOR = "tensor";
Expand All @@ -56,6 +58,7 @@ constexpr const char* RANGE_FILTER = "range_filter";
constexpr const char* INPUT_IDS = "input_ids";
constexpr const char* OUTPUT_TENSOR = "output_tensor";
constexpr const char* DEVICE_ID = "gpu_id";
constexpr const char* NUM_BUILD_THREAD = "num_build_thread";
constexpr const char* TRACE_VISIT = "trace_visit";
constexpr const char* JSON_INFO = "json_info";
constexpr const char* JSON_ID_SET = "json_id_set";
Expand All @@ -67,6 +70,7 @@ constexpr const char* NPROBE = "nprobe";
constexpr const char* NLIST = "nlist";
constexpr const char* NBITS = "nbits"; // PQ/SQ
constexpr const char* M = "m"; // PQ param for IVFPQ
constexpr const char* SSIZE = "ssize";
// HNSW Params
constexpr const char* EFCONSTRUCTION = "efConstruction";
constexpr const char* HNSW_M = "M";
Expand Down
2 changes: 1 addition & 1 deletion include/knowhere/comp/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class ThreadPool {
int omp_before;

public:
explicit ScopedOmpSetter(int num_threads = 1) : omp_before(omp_get_num_threads()) {
explicit ScopedOmpSetter(int num_threads = 1) : omp_before(omp_get_max_threads()) {
omp_set_num_threads(num_threads);
}
~ScopedOmpSetter() {
Expand Down
16 changes: 16 additions & 0 deletions include/knowhere/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
#ifndef CONFIG_H
#define CONFIG_H

#include <omp.h>

#include <iostream>
#include <list>
#include <optional>
#include <sstream>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <variant>
Expand Down Expand Up @@ -392,6 +395,7 @@ class BaseConfig : public Config {
public:
CFG_STRING metric_type;
CFG_INT k;
CFG_INT num_build_thread;
CFG_FLOAT radius;
CFG_FLOAT range_filter;
CFG_BOOL trace_visit;
Expand All @@ -402,6 +406,10 @@ class BaseConfig : public Config {
.description("search for top k similar vector.")
.set_range(1, std::numeric_limits<CFG_INT>::max())
.for_search();
// Thread limit used while building an index; the default of -1 makes
// get_build_thread_num() fall back to omp_get_max_threads().
// The value is only consumed on the build path (Build/Train/Add), so the field
// is tagged for_train() as well; for_search() is kept so existing callers are
// unaffected.
// NOTE(review): assumes for_train() is what exposes a field when a config is
// loaded for build/train -- confirm against Config::Load.
KNOWHERE_CONFIG_DECLARE_FIELD(num_build_thread)
.set_default(-1)
.description("index thread limit for build.")
.for_train()
.for_search();
KNOWHERE_CONFIG_DECLARE_FIELD(radius)
.set_default(0.0)
.description("radius for range search")
Expand All @@ -416,6 +424,14 @@ class BaseConfig : public Config {
.for_search()
.for_range_search();
}

// Thread count to use for index building: the configured num_build_thread
// when it is positive, otherwise whatever omp_get_max_threads() reports.
int
get_build_thread_num() const {
    const bool has_explicit_limit = num_build_thread > 0;
    return has_explicit_limit ? num_build_thread : omp_get_max_threads();
}
};

struct LoadConfig {
Expand Down
58 changes: 56 additions & 2 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ template <typename T>
class IvfIndexNode : public IndexNode {
public:
IvfIndexNode(const Object& object) : index_(nullptr) {
static_assert(std::is_same<T, faiss::IndexIVFFlat>::value || std::is_same<T, faiss::IndexIVFPQ>::value ||
static_assert(std::is_same<T, faiss::IndexIVFFlat>::value || std::is_same<T, faiss::IndexIVFFlatCC>::value ||
std::is_same<T, faiss::IndexIVFPQ>::value ||
std::is_same<T, faiss::IndexIVFScalarQuantizer>::value ||
std::is_same<T, faiss::IndexBinaryIVF>::value,
"not support");
Expand All @@ -66,6 +67,9 @@ class IvfIndexNode : public IndexNode {
if constexpr (std::is_same<faiss::IndexIVFFlat, T>::value) {
return true;
}
if constexpr (std::is_same<faiss::IndexIVFFlatCC, T>::value) {
return false;
}
if constexpr (std::is_same<faiss::IndexIVFPQ, T>::value) {
return false;
}
Expand All @@ -91,6 +95,9 @@ class IvfIndexNode : public IndexNode {
if constexpr (std::is_same<faiss::IndexIVFFlat, T>::value) {
return std::make_unique<IvfFlatConfig>();
}
if constexpr (std::is_same<faiss::IndexIVFFlatCC, T>::value) {
return std::make_unique<IvfFlatCcConfig>();
}
if constexpr (std::is_same<faiss::IndexIVFPQ, T>::value) {
return std::make_unique<IvfPqConfig>();
}
Expand Down Expand Up @@ -119,6 +126,12 @@ class IvfIndexNode : public IndexNode {
auto code_size = index_->code_size;
return (nb * code_size + nb * sizeof(int64_t) + nlist * code_size);
}
if constexpr (std::is_same<T, faiss::IndexIVFFlatCC>::value) {
auto nb = index_->invlists->compute_ntotal();
auto nlist = index_->nlist;
auto code_size = index_->code_size;
return (nb * code_size + nb * sizeof(int64_t) + nlist * code_size);
}
if constexpr (std::is_same<T, faiss::IndexIVFPQ>::value) {
auto nb = index_->invlists->compute_ntotal();
auto code_size = index_->code_size;
Expand Down Expand Up @@ -156,6 +169,9 @@ class IvfIndexNode : public IndexNode {
if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
return knowhere::IndexEnum::INDEX_FAISS_IVFFLAT;
}
if constexpr (std::is_same<T, faiss::IndexIVFFlatCC>::value) {
return knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC;
}
if constexpr (std::is_same<T, faiss::IndexIVFPQ>::value) {
return knowhere::IndexEnum::INDEX_FAISS_IVFPQ;
}
Expand All @@ -179,6 +195,9 @@ namespace knowhere {
template <typename T>
Status
IvfIndexNode<T>::Build(const DataSet& dataset, const Config& cfg) {
const BaseConfig& base_cfg = static_cast<const IvfConfig&>(cfg);
auto build_thread_num = base_cfg.get_build_thread_num();
ThreadPool::ScopedOmpSetter setter(build_thread_num);
auto err = Train(dataset, cfg);
if (err != Status::success) {
return err;
Expand Down Expand Up @@ -223,6 +242,8 @@ template <typename T>
Status
IvfIndexNode<T>::Train(const DataSet& dataset, const Config& cfg) {
const BaseConfig& base_cfg = static_cast<const IvfConfig&>(cfg);
auto build_thread_num = base_cfg.get_build_thread_num();
ThreadPool::ScopedOmpSetter setter(build_thread_num);

// do normalize for COSINE metric type
if (IsMetricType(base_cfg.metric_type, knowhere::metric::COSINE)) {
Expand All @@ -248,6 +269,14 @@ IvfIndexNode<T>::Train(const DataSet& dataset, const Config& cfg) {
index->own_fields = true;
index->train(rows, (const float*)data);
}
if constexpr (std::is_same<faiss::IndexIVFFlatCC, T>::value) {
const IvfFlatCcConfig& ivf_flat_cc_cfg = static_cast<const IvfFlatCcConfig&>(cfg);
auto nlist = MatchNlist(rows, ivf_flat_cc_cfg.nlist);
qzr = new (std::nothrow) typename QuantizerT<T>::type(dim, metric.value());
index = std::make_unique<faiss::IndexIVFFlatCC>(qzr, dim, nlist, ivf_flat_cc_cfg.ssize, metric.value());
index->own_fields = true;
index->train(rows, (const float*)data);
}
if constexpr (std::is_same<faiss::IndexIVFPQ, T>::value) {
const IvfPqConfig& ivf_pq_cfg = static_cast<const IvfPqConfig&>(cfg);
auto nlist = MatchNlist(rows, ivf_pq_cfg.nlist);
Expand Down Expand Up @@ -288,12 +317,15 @@ IvfIndexNode<T>::Train(const DataSet& dataset, const Config& cfg) {

template <typename T>
Status
IvfIndexNode<T>::Add(const DataSet& dataset, const Config&) {
IvfIndexNode<T>::Add(const DataSet& dataset, const Config& cfg) {
if (!this->index_) {
return Status::empty_index;
}
auto data = dataset.GetTensor();
auto rows = dataset.GetRows();
const BaseConfig& base_cfg = static_cast<const IvfConfig&>(cfg);
auto build_thread_num = base_cfg.get_build_thread_num();
ThreadPool::ScopedOmpSetter setter(build_thread_num);
try {
if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
index_->add_without_codes(rows, (const float*)data);
Expand Down Expand Up @@ -537,6 +569,22 @@ IvfIndexNode<T>::GetVectorByIds(const DataSet& dataset, const Config& cfg) const
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return unexpected(Status::faiss_inner_error);
}
} else if constexpr (std::is_same<T, faiss::IndexIVFFlatCC>::value) {
// Reconstruct every requested id into one contiguous buffer of rows * dim
// floats and hand that buffer to the result dataset.
float* data = nullptr;
try {
    data = new float[dim * rows];
    index_->make_direct_map(true);
    for (int64_t i = 0; i < rows; i++) {
        int64_t id = ids[i];
        // Only checked in builds where assert is enabled.
        assert(id >= 0 && id < index_->ntotal);
        index_->reconstruct(id, data + i * dim);
    }
    return GenResultDataSet(rows, dim, data);
} catch (const std::exception& e) {
    // The buffer came from new[], so it has to be released with delete[];
    // unique_ptr<float> would call scalar delete, which is undefined behaviour.
    // data is still nullptr here if the allocation itself threw.
    std::unique_ptr<float[]> auto_del(data);
    LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
    return unexpected(Status::faiss_inner_error);
}
} else {
return unexpected(Status::not_implemented);
}
Expand Down Expand Up @@ -701,6 +749,12 @@ KNOWHERE_REGISTER_GLOBAL(IVFFLAT,
[](const Object& object) { return Index<IvfIndexNode<faiss::IndexIVFFlat>>::Create(object); });
KNOWHERE_REGISTER_GLOBAL(IVF_FLAT,
[](const Object& object) { return Index<IvfIndexNode<faiss::IndexIVFFlat>>::Create(object); });
// Register the concurrent IVF flat index under two identifiers, IVFFLATCC and
// IVF_FLAT_CC. Both factories create the same
// Index<IvfIndexNode<faiss::IndexIVFFlatCC>>.
KNOWHERE_REGISTER_GLOBAL(IVFFLATCC, [](const Object& object) {
return Index<IvfIndexNode<faiss::IndexIVFFlatCC>>::Create(object);
});
KNOWHERE_REGISTER_GLOBAL(IVF_FLAT_CC, [](const Object& object) {
return Index<IvfIndexNode<faiss::IndexIVFFlatCC>>::Create(object);
});
KNOWHERE_REGISTER_GLOBAL(IVFPQ,
[](const Object& object) { return Index<IvfIndexNode<faiss::IndexIVFPQ>>::Create(object); });
KNOWHERE_REGISTER_GLOBAL(IVF_PQ,
Expand Down
12 changes: 12 additions & 0 deletions src/index/ivf/ivf_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ class IvfConfig : public BaseConfig {

class IvfFlatConfig : public IvfConfig {};

// Configuration for the IVF_FLAT_CC index: everything IvfFlatConfig provides
// plus one extra train-time parameter, ssize.
class IvfFlatCcConfig : public IvfFlatConfig {
public:
// Segment size; passed to the faiss::IndexIVFFlatCC constructor during Train.
int ssize;
KNOHWERE_DECLARE_CONFIG(IvfFlatCcConfig) {
// Train parameter, default 48, range bounds 32 and 2048.
// NOTE(review): whether the set_range bounds are inclusive is not visible
// here -- confirm against the config framework.
KNOWHERE_CONFIG_DECLARE_FIELD(ssize)
.description("segment size")
.set_default(48)
.for_train()
.set_range(32, 2048);
}
};

class IvfPqConfig : public IvfConfig {
public:
int m;
Expand Down
11 changes: 11 additions & 0 deletions tests/python/test_index_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@
"nprobe": 1024,
},
),
(
"IVFFLATCC",
{
"dim": 256,
"k": 15,
"metric_type": "L2",
"n_list": 1024,
"nprobe": 1024,
"ssize" : 48
},
),
(
"IVFSQ",
{
Expand Down
11 changes: 11 additions & 0 deletions tests/python/test_index_with_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@
"nprobe": 1024,
},
),
(
"IVFFLATCC",
{
"dim": 256,
"k": 15,
"metric_type": "L2",
"nlist": 1024,
"nprobe": 1024,
"ssize": 48
},
),
# (
# "IVFSQ",
# {
Expand Down
11 changes: 11 additions & 0 deletions tests/python/test_index_with_sift.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ def download_sift():
"nprobe": 128,
},
),
(
"IVFFLATCC",
{
"dim": 128,
"k": 100,
"metric_type": "L2",
"n_list": 1024,
"nprobe": 128,
"ssize": 48
},
),
(
"IVFSQ",
{
Expand Down
7 changes: 7 additions & 0 deletions tests/ut/test_get_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ TEST_CASE("Test Get Vector By Ids", "[GetVectorByIds]") {
return json;
};

auto ivfflatcc_gen = [&ivfflat_gen]() {
knowhere::Json json = ivfflat_gen();
json[knowhere::indexparam::SSIZE] = 48;
return json;
};

auto bin_ivfflat_gen = [&base_bin_gen]() {
knowhere::Json json = base_bin_gen();
json[knowhere::indexparam::NLIST] = 16;
Expand Down Expand Up @@ -118,6 +124,7 @@ TEST_CASE("Test Get Vector By Ids", "[GetVectorByIds]") {
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IDMAP, flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create(name);
Expand Down
Loading

0 comments on commit 717a7b0

Please sign in to comment.