Skip to content

Commit

Permalink
Remove IVF_FLAT_NM (#39)
Browse files Browse the repository at this point in the history
Signed-off-by: Yudong Cai <yudong.cai@zilliz.com>
  • Loading branch information
cydrain authored Sep 7, 2023
1 parent 95f7078 commit dbd7756
Show file tree
Hide file tree
Showing 31 changed files with 131 additions and 1,535 deletions.
2 changes: 1 addition & 1 deletion benchmark/hdf5/benchmark_binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class Benchmark_binary : public Benchmark_knowhere, public ::testing::Test {
parse_ann_test_name();
load_hdf5_data<true>();

assert(metric_str_ == METRIC_HAM_STR || metric_str_ == METRIC_JAC_STR || metric_str_ == METRIC_TAN_STR);
assert(metric_str_ == METRIC_HAM_STR || metric_str_ == METRIC_JAC_STR);
metric_type_ = (metric_str_ == METRIC_HAM_STR) ? knowhere::metric::HAMMING : knowhere::metric::JACCARD;
cfg_[knowhere::meta::METRIC_TYPE] = metric_type_;
knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2);
Expand Down
2 changes: 1 addition & 1 deletion benchmark/hdf5/benchmark_binary_range.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class Benchmark_binary_range : public Benchmark_knowhere, public ::testing::Test
load_hdf5_data_range<true>();
#endif

assert(metric_str_ == METRIC_HAM_STR || metric_str_ == METRIC_JAC_STR || metric_str_ == METRIC_TAN_STR);
assert(metric_str_ == METRIC_HAM_STR || metric_str_ == METRIC_JAC_STR);
metric_type_ = (metric_str_ == METRIC_HAM_STR) ? knowhere::metric::HAMMING : knowhere::metric::JACCARD;
cfg_[knowhere::meta::METRIC_TYPE] = metric_type_;
cfg_[knowhere::meta::RADIUS] = *gt_radius_;
Expand Down
2 changes: 1 addition & 1 deletion benchmark/hdf5/benchmark_float_range_bitset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Benchmark_float_range_bitset : public Benchmark_knowhere, public ::testing
auto conf = cfg;
auto radius = conf[knowhere::meta::RADIUS].get<float>();

printf("\n[%0.3f s] %s | %s | radius=.3f\n", get_time_diff(), ann_test_name_.c_str(), index_type_.c_str(),
printf("\n[%0.3f s] %s | %s | radius=%.3f\n", get_time_diff(), ann_test_name_.c_str(), index_type_.c_str(),
radius);
printf("================================================================================\n");
for (auto per : PERCENTs_) {
Expand Down
2 changes: 1 addition & 1 deletion benchmark/hdf5/benchmark_knowhere.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 {

// IVFFLAT_NM should load raw data
knowhere::BinaryPtr bin = std::make_shared<knowhere::Binary>();
bin->data = std::shared_ptr<uint8_t[]>((uint8_t*)xb_, [&](uint8_t*) {});
bin->data = std::shared_ptr<uint8_t[]>((uint8_t*)xb_);
bin->size = dim_ * nb_ * sizeof(float);
binary_set.Append("RAW_DATA", bin);

Expand Down
4 changes: 4 additions & 0 deletions include/knowhere/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include <vector>

#include "knowhere/binaryset.h"
#include "knowhere/dataset.h"

namespace knowhere {
Expand Down Expand Up @@ -65,4 +66,7 @@ round_down(const T value, const T align) {
return value / align * align;
}

void
ConvertIVFFlatIfNeeded(const BinarySet& binset, const uint8_t* raw_data, const size_t raw_size);

} // namespace knowhere
40 changes: 40 additions & 0 deletions src/common/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
#include <cmath>
#include <cstdint>

#include "faiss/IndexIVFFlat.h"
#include "faiss/impl/FaissException.h"
#include "faiss/index_io.h"
#include "io/memory_io.h"
#include "knowhere/log.h"
#include "simd/hook.h"

Expand Down Expand Up @@ -66,4 +70,40 @@ CopyAndNormalizeFloatVec(const float* x, int32_t dim) {
return x_norm;
}

void
ConvertIVFFlatIfNeeded(const BinarySet& binset, const uint8_t* raw_data, const size_t raw_size) {
std::vector<std::string> names = {"IVF", // compatible with knowhere-1.x
knowhere::IndexEnum::INDEX_FAISS_IVFFLAT};
auto binary = binset.GetByNames(names);
if (binary == nullptr) {
return;
}

MemoryIOReader reader(binary->data.get(), binary->size);

// there are 2 possibilities for the input index binary:
// 1. native IVF_FLAT, do nothing
// 2. IVF_FLAT_NM, convert to native IVF_FLAT
try {
// try to parse as native format, if it's actually _NM format,
// faiss will raise a "read error" exception for IVF_FLAT_NM format
faiss::read_index(&reader);
} catch (faiss::FaissException& e) {
reader.reset();

// convert IVF_FLAT_NM to native IVF_FLAT
auto* index = static_cast<faiss::IndexIVFFlat*>(faiss::read_index_nm(&reader));
index->restore_codes(raw_data, raw_size);

// over-write IVF_FLAT_NM binary with native IVF_FLAT binary
MemoryIOWriter writer;
faiss::write_index(index, &writer);
std::shared_ptr<uint8_t[]> data(writer.data());
binary->data = data;
binary->size = writer.tellg();

LOG_KNOWHERE_INFO_ << "Convert IVF_FLAT_NM to native IVF_FLAT";
}
}

} // namespace knowhere
94 changes: 26 additions & 68 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class IvfIndexNode : public IndexNode {
auto nb = index_->invlists->compute_ntotal();
auto nlist = index_->nlist;
auto code_size = index_->code_size;
return (nb * code_size + nb * sizeof(int64_t) + nlist * code_size);
return ((nb + nlist) * (code_size + sizeof(int64_t)));
}
if constexpr (std::is_same<T, faiss::IndexIVFFlatCC>::value) {
auto nb = index_->invlists->compute_ntotal();
Expand Down Expand Up @@ -352,14 +352,11 @@ IvfIndexNode<T>::Add(const DataSet& dataset, const Config& cfg) {
setter = std::make_unique<ThreadPool::ScopedOmpSetter>(base_cfg.num_build_thread.value());
}
try {
if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
index_->add_without_codes(rows, (const float*)data);
} else if constexpr (std::is_same<faiss::IndexBinaryIVF, T>::value) {
if constexpr (std::is_same<faiss::IndexBinaryIVF, T>::value) {
index_->add(rows, (const uint8_t*)data);
} else {
index_->add(rows, (const float*)data);
}

} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return Status::faiss_inner_error;
Expand Down Expand Up @@ -419,14 +416,16 @@ IvfIndexNode<T>::Search(const DataSet& dataset, const Config& cfg, const BitsetV
std::lock_guard<std::mutex> lock(normalize_mtx_);
if (!normalized_) {
faiss::IndexIVFFlat* ivf_index = static_cast<faiss::IndexIVFFlat*>(index_.get());
size_t nb = ivf_index->arranged_codes.size() / ivf_index->code_size;
NormalizeVecs((float*)(ivf_index->arranged_codes.data()), nb, dim);
auto nlist = ivf_index->invlists->nlist;
for (size_t i = 0; i < nlist; i++) {
auto list_size = ivf_index->invlists->list_size(i);
NormalizeVecs((float*)(ivf_index->invlists->get_codes(i)), list_size, dim);
}
normalized_ = true;
}
}
}
index_->search_without_codes_thread_safe(1, cur_query, k, distances + offset, ids + offset, nprobe,
0, bitset);
index_->search_thread_safe(1, cur_query, k, distances + offset, ids + offset, nprobe, 0, bitset);
} else if constexpr (std::is_same<T, faiss::IndexScaNN>::value) {
auto cur_query = (const float*)data + index * dim;
const ScannConfig& scann_cfg = static_cast<const ScannConfig&>(cfg);
Expand Down Expand Up @@ -514,14 +513,16 @@ IvfIndexNode<T>::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi
std::lock_guard<std::mutex> lock(normalize_mtx_);
if (!normalized_) {
faiss::IndexIVFFlat* ivf_index = static_cast<faiss::IndexIVFFlat*>(index_.get());
size_t nb = ivf_index->arranged_codes.size() / ivf_index->code_size;
NormalizeVecs((float*)(ivf_index->arranged_codes.data()), nb, dim);
auto nlist = ivf_index->invlists->nlist;
for (size_t i = 0; i < nlist; i++) {
auto list_size = ivf_index->invlists->list_size(i);
NormalizeVecs((float*)(ivf_index->invlists->get_codes(i)), list_size, dim);
}
normalized_ = true;
}
}
}
index_->range_search_without_codes_thread_safe(1, cur_query, radius, &res, index_->nlist, 0,
bitset);
index_->range_search_thread_safe(1, cur_query, radius, &res, index_->nlist, 0, bitset);
} else if constexpr (std::is_same<T, faiss::IndexScaNN>::value) {
auto cur_query = (const float*)xq + index * dim;
if (is_cosine) {
Expand Down Expand Up @@ -592,27 +593,7 @@ IvfIndexNode<T>::GetVectorByIds(const DataSet& dataset) const {
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return expected<DataSetPtr>::Err(Status::faiss_inner_error, e.what());
}
} else if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
auto dim = Dim();
auto rows = dataset.GetRows();
auto ids = dataset.GetIds();

float* data = nullptr;
try {
data = new float[dim * rows];
index_->make_direct_map(true);
for (int64_t i = 0; i < rows; i++) {
int64_t id = ids[i];
assert(id >= 0 && id < index_->ntotal);
index_->reconstruct_without_codes(id, data + i * dim);
}
return GenResultDataSet(rows, dim, data);
} catch (const std::exception& e) {
std::unique_ptr<float[]> auto_del(data);
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return expected<DataSetPtr>::Err(Status::faiss_inner_error, e.what());
}
} else if constexpr (std::is_same<T, faiss::IndexIVFFlatCC>::value) {
} else if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value || std::is_same<T, faiss::IndexIVFFlatCC>::value) {
auto dim = Dim();
auto rows = dataset.GetRows();
auto ids = dataset.GetIds();
Expand Down Expand Up @@ -675,7 +656,7 @@ IvfIndexNode<faiss::IndexIVFFlat>::GetIndexMeta(const Config& config) const {
std::unordered_set<int64_t> id_set;

for (int32_t i = 0; i < nlist; i++) {
// copy from IndexIVF::search_preassigned_without_codes
// copy from IndexIVF::search_preassigned
std::unique_ptr<faiss::InvertedLists::ScopedIds> sids =
std::make_unique<faiss::InvertedLists::ScopedIds>(index_->invlists, i);

Expand All @@ -702,8 +683,6 @@ IvfIndexNode<T>::Serialize(BinarySet& binset) const {
MemoryIOWriter writer;
if constexpr (std::is_same<T, faiss::IndexBinaryIVF>::value) {
faiss::write_index_binary(index_.get(), &writer);
} else if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
faiss::write_index_nm(index_.get(), &writer);
} else {
faiss::write_index(index_.get(), &writer);
}
Expand All @@ -730,7 +709,16 @@ IvfIndexNode<T>::Deserialize(const BinarySet& binset, const Config& config) {

MemoryIOReader reader(binary->data.get(), binary->size);
try {
if constexpr (std::is_same<T, faiss::IndexBinaryIVF>::value) {
if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
auto raw_binary = binset.GetByName("RAW_DATA");
if (raw_binary != nullptr) {
ConvertIVFFlatIfNeeded(binset, raw_binary->data.get(), raw_binary->size);
// after conversion, binary size and data will be updated
reader.data_ = binary->data.get();
reader.total_ = binary->size;
}
index_.reset(static_cast<faiss::IndexIVFFlat*>(faiss::read_index(&reader)));
} else if constexpr (std::is_same<T, faiss::IndexBinaryIVF>::value) {
index_.reset(static_cast<T*>(faiss::read_index_binary(&reader)));
} else {
index_.reset(static_cast<T*>(faiss::read_index(&reader)));
Expand Down Expand Up @@ -764,36 +752,6 @@ IvfIndexNode<T>::DeserializeFromFile(const std::string& filename, const Config&
return Status::success;
}

template <>
Status
IvfIndexNode<faiss::IndexIVFFlat>::Deserialize(const BinarySet& binset, const Config& config) {
std::vector<std::string> names = {"IVF", // compatible with knowhere-1.x
Type()};
auto binary = binset.GetByNames(names);
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
return Status::invalid_binary_set;
}

MemoryIOReader reader(binary->data.get(), binary->size);
try {
index_.reset(static_cast<faiss::IndexIVFFlat*>(faiss::read_index_nm(&reader)));

// Construct arranged data from original data
auto binary = binset.GetByName("RAW_DATA");
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
return Status::invalid_binary_set;
}
size_t nb = binary->size / index_->invlists->code_size;
index_->arrange_codes(nb, (const float*)(binary->data.get()));
} catch (const std::exception& e) {
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return Status::faiss_inner_error;
}
return Status::success;
}

KNOWHERE_REGISTER_GLOBAL(IVFBIN, [](const Object& object) {
return Index<IvfIndexNode<faiss::IndexBinaryIVF>>::Create(object);
});
Expand Down
18 changes: 0 additions & 18 deletions tests/ut/test_feder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,22 +167,6 @@ TEST_CASE("Test Feder", "[feder]") {
return json;
};

auto load_raw_data = [](knowhere::Index<knowhere::IndexNode>& index, const knowhere::DataSet& dataset,
const knowhere::Json& conf) {
auto rows = dataset.GetRows();
auto dim = dataset.GetDim();
auto p_data = dataset.GetTensor();
knowhere::BinarySet bs;
auto res = index.Serialize(bs);
REQUIRE(res == knowhere::Status::success);
knowhere::BinaryPtr bptr = std::make_shared<knowhere::Binary>();
bptr->data = std::shared_ptr<uint8_t[]>((uint8_t*)p_data, [&](uint8_t*) {});
bptr->size = dim * rows * sizeof(float);
bs.Append("RAW_DATA", bptr);
res = index.Deserialize(bs);
REQUIRE(res == knowhere::Status::success);
};

const auto train_ds = GenDataSet(nb, dim, seed);
const auto query_ds = GenDataSet(nq, dim, seed);

Expand Down Expand Up @@ -222,8 +206,6 @@ TEST_CASE("Test Feder", "[feder]") {
auto res = idx.Build(*train_ds, json);
REQUIRE(res == knowhere::Status::success);

load_raw_data(idx, *train_ds, json);

auto res1 = idx.GetIndexMeta(json);
REQUIRE(res1.has_value());
CheckIvfFlatMeta(res1.value(), nb, json);
Expand Down
19 changes: 0 additions & 19 deletions tests/ut/test_get_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,22 +139,6 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") {

auto flat_gen = base_gen;

auto load_raw_data = [](knowhere::Index<knowhere::IndexNode>& index, const knowhere::DataSet& dataset,
const knowhere::Json& conf) {
auto rows = dataset.GetRows();
auto dim = dataset.GetDim();
auto p_data = dataset.GetTensor();
knowhere::BinarySet bs;
auto res = index.Serialize(bs);
REQUIRE(res == knowhere::Status::success);
knowhere::BinaryPtr bptr = std::make_shared<knowhere::Binary>();
bptr->data = std::shared_ptr<uint8_t[]>((uint8_t*)p_data, [&](uint8_t*) {});
bptr->size = dim * rows * sizeof(float);
bs.Append("RAW_DATA", bptr);
res = index.Deserialize(bs);
REQUIRE(res == knowhere::Status::success);
};

SECTION("Test float index") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
Expand Down Expand Up @@ -184,9 +168,6 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") {

auto idx_new = knowhere::IndexFactory::Instance().Create(name);
idx_new.Deserialize(bs);
if (name == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) {
load_raw_data(idx_new, *train_ds, json);
}
auto results = idx_new.GetVectorByIds(*ids_ds);
REQUIRE(results.has_value());
auto xb = (float*)train_ds_copy->GetTensor();
Expand Down
15 changes: 0 additions & 15 deletions tests/ut/test_mmap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,6 @@ TEST_CASE("Search mmap", "[float metrics]") {
auto data = bs.GetByName(index.Type());

WriteDataToDisk(path.string(), reinterpret_cast<const char*>(data->data.get()), data->size);

// knowhere::BinaryPtr bptr = std::make_shared<knowhere::Binary>();
// bptr->data = std::shared_ptr<uint8_t[]>((uint8_t*)p_data, [&](uint8_t*) {});
// bptr->size = dim * rows * sizeof(float);
// bs.Append("RAW_DATA", bptr);
REQUIRE(index.DeserializeFromFile(path, conf) == knowhere::Status::success);
};

Expand Down Expand Up @@ -262,11 +257,6 @@ TEST_CASE("Search binary mmap", "[float metrics]") {
auto data = bs.GetByName(index.Type());

WriteDataToDisk(path.string(), reinterpret_cast<const char*>(data->data.get()), data->size);

// knowhere::BinaryPtr bptr = std::make_shared<knowhere::Binary>();
// bptr->data = std::shared_ptr<uint8_t[]>((uint8_t*)p_data, [&](uint8_t*) {});
// bptr->size = dim * rows * sizeof(float);
// bs.Append("RAW_DATA", bptr);
REQUIRE(index.DeserializeFromFile(path, conf) == knowhere::Status::success);
};

Expand Down Expand Up @@ -373,11 +363,6 @@ TEST_CASE("Search binary mmap", "[bool metrics]") {
auto data = bs.GetByName(index.Type());

WriteDataToDisk(path.string(), reinterpret_cast<const char*>(data->data.get()), data->size);

// knowhere::BinaryPtr bptr = std::make_shared<knowhere::Binary>();
// bptr->data = std::shared_ptr<uint8_t[]>((uint8_t*)p_data, [&](uint8_t*) {});
// bptr->size = dim * rows * sizeof(float);
// bs.Append("RAW_DATA", bptr);
REQUIRE(index.DeserializeFromFile(path, conf) == knowhere::Status::success);
};

Expand Down
Loading

0 comments on commit dbd7756

Please sign in to comment.