Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

add ivf_flat_cc index #824

Merged
merged 1 commit into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ constexpr const char* INDEX_FAISS_BIN_IVFFLAT = "BIN_IVF_FLAT";

constexpr const char* INDEX_FAISS_IDMAP = "FLAT";
constexpr const char* INDEX_FAISS_IVFFLAT = "IVF_FLAT";
constexpr const char* INDEX_FAISS_IVFFLAT_CC = "IVF_FLAT_CC";
constexpr const char* INDEX_FAISS_IVFPQ = "IVF_PQ";
constexpr const char* INDEX_FAISS_IVFSQ8 = "IVF_SQ8";

Expand All @@ -43,6 +44,7 @@ constexpr const char* INDEX_DISKANN = "DISKANN";
} // namespace IndexEnum

namespace meta {
constexpr const char* INDEX_TYPE = "index_type";
constexpr const char* METRIC_TYPE = "metric_type";
constexpr const char* DIM = "dim";
constexpr const char* TENSOR = "tensor";
Expand All @@ -56,6 +58,7 @@ constexpr const char* RANGE_FILTER = "range_filter";
constexpr const char* INPUT_IDS = "input_ids";
constexpr const char* OUTPUT_TENSOR = "output_tensor";
constexpr const char* DEVICE_ID = "gpu_id";
constexpr const char* NUM_BUILD_THREAD = "num_build_thread";
constexpr const char* TRACE_VISIT = "trace_visit";
constexpr const char* JSON_INFO = "json_info";
constexpr const char* JSON_ID_SET = "json_id_set";
Expand All @@ -67,6 +70,7 @@ constexpr const char* NPROBE = "nprobe";
constexpr const char* NLIST = "nlist";
constexpr const char* NBITS = "nbits"; // PQ/SQ
constexpr const char* M = "m"; // PQ param for IVFPQ
constexpr const char* SSIZE = "ssize";
// HNSW Params
constexpr const char* EFCONSTRUCTION = "efConstruction";
constexpr const char* HNSW_M = "M";
Expand Down
2 changes: 1 addition & 1 deletion include/knowhere/comp/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class ThreadPool {
int omp_before;

public:
explicit ScopedOmpSetter(int num_threads = 1) : omp_before(omp_get_num_threads()) {
explicit ScopedOmpSetter(int num_threads = 1) : omp_before(omp_get_max_threads()) {
omp_set_num_threads(num_threads);
}
~ScopedOmpSetter() {
Expand Down
16 changes: 16 additions & 0 deletions include/knowhere/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
#ifndef CONFIG_H
#define CONFIG_H

#include <omp.h>

#include <iostream>
#include <list>
#include <optional>
#include <sstream>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <variant>
Expand Down Expand Up @@ -392,6 +395,7 @@ class BaseConfig : public Config {
public:
CFG_STRING metric_type;
CFG_INT k;
CFG_INT num_build_thread;
CFG_FLOAT radius;
CFG_FLOAT range_filter;
CFG_BOOL trace_visit;
Expand All @@ -402,6 +406,10 @@ class BaseConfig : public Config {
.description("search for top k similar vector.")
.set_range(1, std::numeric_limits<CFG_INT>::max())
.for_search();
KNOWHERE_CONFIG_DECLARE_FIELD(num_build_thread)
.set_default(-1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-1 is a valid value ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need a default value to indicate user has not set this value, and '-1' gives this semantic.

.description("index thread limit for build.")
.for_search();
KNOWHERE_CONFIG_DECLARE_FIELD(radius)
.set_default(0.0)
.description("radius for range search")
Expand All @@ -416,6 +424,14 @@ class BaseConfig : public Config {
.for_search()
.for_range_search();
}

int
get_build_thread_num() const {
if (num_build_thread > 0) {
return num_build_thread;
}
return omp_get_max_threads();
}
};

struct LoadConfig {
Expand Down
58 changes: 56 additions & 2 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ template <typename T>
class IvfIndexNode : public IndexNode {
public:
IvfIndexNode(const Object& object) : index_(nullptr) {
static_assert(std::is_same<T, faiss::IndexIVFFlat>::value || std::is_same<T, faiss::IndexIVFPQ>::value ||
static_assert(std::is_same<T, faiss::IndexIVFFlat>::value || std::is_same<T, faiss::IndexIVFFlatCC>::value ||
std::is_same<T, faiss::IndexIVFPQ>::value ||
std::is_same<T, faiss::IndexIVFScalarQuantizer>::value ||
std::is_same<T, faiss::IndexBinaryIVF>::value,
"not support");
Expand All @@ -66,6 +67,9 @@ class IvfIndexNode : public IndexNode {
if constexpr (std::is_same<faiss::IndexIVFFlat, T>::value) {
return true;
}
if constexpr (std::is_same<faiss::IndexIVFFlatCC, T>::value) {
return false;
}
if constexpr (std::is_same<faiss::IndexIVFPQ, T>::value) {
return false;
}
Expand All @@ -91,6 +95,9 @@ class IvfIndexNode : public IndexNode {
if constexpr (std::is_same<faiss::IndexIVFFlat, T>::value) {
return std::make_unique<IvfFlatConfig>();
}
if constexpr (std::is_same<faiss::IndexIVFFlatCC, T>::value) {
return std::make_unique<IvfFlatCcConfig>();
}
if constexpr (std::is_same<faiss::IndexIVFPQ, T>::value) {
return std::make_unique<IvfPqConfig>();
}
Expand Down Expand Up @@ -119,6 +126,12 @@ class IvfIndexNode : public IndexNode {
auto code_size = index_->code_size;
return (nb * code_size + nb * sizeof(int64_t) + nlist * code_size);
}
if constexpr (std::is_same<T, faiss::IndexIVFFlatCC>::value) {
auto nb = index_->invlists->compute_ntotal();
auto nlist = index_->nlist;
auto code_size = index_->code_size;
return (nb * code_size + nb * sizeof(int64_t) + nlist * code_size);
}
if constexpr (std::is_same<T, faiss::IndexIVFPQ>::value) {
auto nb = index_->invlists->compute_ntotal();
auto code_size = index_->code_size;
Expand Down Expand Up @@ -156,6 +169,9 @@ class IvfIndexNode : public IndexNode {
if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
return knowhere::IndexEnum::INDEX_FAISS_IVFFLAT;
}
if constexpr (std::is_same<T, faiss::IndexIVFFlatCC>::value) {
return knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC;
}
if constexpr (std::is_same<T, faiss::IndexIVFPQ>::value) {
return knowhere::IndexEnum::INDEX_FAISS_IVFPQ;
}
Expand All @@ -179,6 +195,9 @@ namespace knowhere {
template <typename T>
Status
IvfIndexNode<T>::Build(const DataSet& dataset, const Config& cfg) {
const BaseConfig& base_cfg = static_cast<const IvfConfig&>(cfg);
auto build_thread_num = base_cfg.get_build_thread_num();
ThreadPool::ScopedOmpSetter setter(build_thread_num);
auto err = Train(dataset, cfg);
if (err != Status::success) {
return err;
Expand Down Expand Up @@ -223,6 +242,8 @@ template <typename T>
Status
IvfIndexNode<T>::Train(const DataSet& dataset, const Config& cfg) {
const BaseConfig& base_cfg = static_cast<const IvfConfig&>(cfg);
auto build_thread_num = base_cfg.get_build_thread_num();
ThreadPool::ScopedOmpSetter setter(build_thread_num);

// do normalize for COSINE metric type
if (IsMetricType(base_cfg.metric_type, knowhere::metric::COSINE)) {
Expand All @@ -248,6 +269,14 @@ IvfIndexNode<T>::Train(const DataSet& dataset, const Config& cfg) {
index->own_fields = true;
index->train(rows, (const float*)data);
}
if constexpr (std::is_same<faiss::IndexIVFFlatCC, T>::value) {
const IvfFlatCcConfig& ivf_flat_cc_cfg = static_cast<const IvfFlatCcConfig&>(cfg);
auto nlist = MatchNlist(rows, ivf_flat_cc_cfg.nlist);
qzr = new (std::nothrow) typename QuantizerT<T>::type(dim, metric.value());
index = std::make_unique<faiss::IndexIVFFlatCC>(qzr, dim, nlist, ivf_flat_cc_cfg.ssize, metric.value());
index->own_fields = true;
index->train(rows, (const float*)data);
}
if constexpr (std::is_same<faiss::IndexIVFPQ, T>::value) {
const IvfPqConfig& ivf_pq_cfg = static_cast<const IvfPqConfig&>(cfg);
auto nlist = MatchNlist(rows, ivf_pq_cfg.nlist);
Expand Down Expand Up @@ -288,12 +317,15 @@ IvfIndexNode<T>::Train(const DataSet& dataset, const Config& cfg) {

template <typename T>
Status
IvfIndexNode<T>::Add(const DataSet& dataset, const Config&) {
IvfIndexNode<T>::Add(const DataSet& dataset, const Config& cfg) {
if (!this->index_) {
return Status::empty_index;
}
auto data = dataset.GetTensor();
auto rows = dataset.GetRows();
const BaseConfig& base_cfg = static_cast<const IvfConfig&>(cfg);
auto build_thread_num = base_cfg.get_build_thread_num();
ThreadPool::ScopedOmpSetter setter(build_thread_num);
try {
if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
index_->add_without_codes(rows, (const float*)data);
Expand Down Expand Up @@ -537,6 +569,22 @@ IvfIndexNode<T>::GetVectorByIds(const DataSet& dataset, const Config& cfg) const
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return unexpected(Status::faiss_inner_error);
}
} else if constexpr (std::is_same<T, faiss::IndexIVFFlatCC>::value) {
float* data = nullptr;
try {
data = new float[dim * rows];
index_->make_direct_map(true);
for (int64_t i = 0; i < rows; i++) {
int64_t id = ids[i];
assert(id >= 0 && id < index_->ntotal);
index_->reconstruct(id, data + i * dim);
}
return GenResultDataSet(rows, dim, data);
} catch (const std::exception& e) {
std::unique_ptr<float> auto_del(data);
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return unexpected(Status::faiss_inner_error);
}
} else {
return unexpected(Status::not_implemented);
}
Expand Down Expand Up @@ -701,6 +749,12 @@ KNOWHERE_REGISTER_GLOBAL(IVFFLAT,
[](const Object& object) { return Index<IvfIndexNode<faiss::IndexIVFFlat>>::Create(object); });
KNOWHERE_REGISTER_GLOBAL(IVF_FLAT,
[](const Object& object) { return Index<IvfIndexNode<faiss::IndexIVFFlat>>::Create(object); });
KNOWHERE_REGISTER_GLOBAL(IVFFLATCC, [](const Object& object) {
foxspy marked this conversation as resolved.
Show resolved Hide resolved
return Index<IvfIndexNode<faiss::IndexIVFFlatCC>>::Create(object);
});
KNOWHERE_REGISTER_GLOBAL(IVF_FLAT_CC, [](const Object& object) {
return Index<IvfIndexNode<faiss::IndexIVFFlatCC>>::Create(object);
});
KNOWHERE_REGISTER_GLOBAL(IVFPQ,
[](const Object& object) { return Index<IvfIndexNode<faiss::IndexIVFPQ>>::Create(object); });
KNOWHERE_REGISTER_GLOBAL(IVF_PQ,
Expand Down
12 changes: 12 additions & 0 deletions src/index/ivf/ivf_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ class IvfConfig : public BaseConfig {

class IvfFlatConfig : public IvfConfig {};

class IvfFlatCcConfig : public IvfFlatConfig {
public:
int ssize;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's better to name it to "segment_size"

KNOHWERE_DECLARE_CONFIG(IvfFlatCcConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(ssize)
.description("segment size")
.set_default(48)
.for_train()
.set_range(32, 2048);
}
};

class IvfPqConfig : public IvfConfig {
public:
int m;
Expand Down
11 changes: 11 additions & 0 deletions tests/python/test_index_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@
"nprobe": 1024,
},
),
(
"IVFFLATCC",
{
"dim": 256,
"k": 15,
"metric_type": "L2",
"n_list": 1024,
"nprobe": 1024,
"ssize" : 48
},
),
(
"IVFSQ",
{
Expand Down
11 changes: 11 additions & 0 deletions tests/python/test_index_with_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@
"nprobe": 1024,
},
),
(
"IVFFLATCC",
{
"dim": 256,
"k": 15,
"metric_type": "L2",
"nlist": 1024,
"nprobe": 1024,
"ssize": 48
},
),
# (
# "IVFSQ",
# {
Expand Down
11 changes: 11 additions & 0 deletions tests/python/test_index_with_sift.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ def download_sift():
"nprobe": 128,
},
),
(
"IVFFLATCC",
{
"dim": 128,
"k": 100,
"metric_type": "L2",
"n_list": 1024,
"nprobe": 128,
"ssize": 48
},
),
(
"IVFSQ",
{
Expand Down
7 changes: 7 additions & 0 deletions tests/ut/test_get_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ TEST_CASE("Test Get Vector By Ids", "[GetVectorByIds]") {
return json;
};

auto ivfflatcc_gen = [&ivfflat_gen]() {
knowhere::Json json = ivfflat_gen();
json[knowhere::indexparam::SSIZE] = 48;
return json;
};

auto bin_ivfflat_gen = [&base_bin_gen]() {
knowhere::Json json = base_bin_gen();
json[knowhere::indexparam::NLIST] = 16;
Expand Down Expand Up @@ -118,6 +124,7 @@ TEST_CASE("Test Get Vector By Ids", "[GetVectorByIds]") {
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IDMAP, flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create(name);
Expand Down
Loading