Skip to content

Commit

Permalink
Remove IVF_FLAT_NM
Browse files Browse the repository at this point in the history
Signed-off-by: Yudong Cai <yudong.cai@zilliz.com>
  • Loading branch information
cydrain committed Sep 1, 2023
1 parent 3d21950 commit 69903a7
Show file tree
Hide file tree
Showing 32 changed files with 141 additions and 1,576 deletions.
2 changes: 1 addition & 1 deletion benchmark/hdf5/benchmark_binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class Benchmark_binary : public Benchmark_knowhere, public ::testing::Test {
parse_ann_test_name();
load_hdf5_data<true>();

assert(metric_str_ == METRIC_HAM_STR || metric_str_ == METRIC_JAC_STR || metric_str_ == METRIC_TAN_STR);
assert(metric_str_ == METRIC_HAM_STR || metric_str_ == METRIC_JAC_STR);
metric_type_ = (metric_str_ == METRIC_HAM_STR) ? knowhere::metric::HAMMING : knowhere::metric::JACCARD;
cfg_[knowhere::meta::METRIC_TYPE] = metric_type_;
knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2);
Expand Down
2 changes: 1 addition & 1 deletion benchmark/hdf5/benchmark_binary_range.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class Benchmark_binary_range : public Benchmark_knowhere, public ::testing::Test
load_hdf5_data_range<true>();
#endif

assert(metric_str_ == METRIC_HAM_STR || metric_str_ == METRIC_JAC_STR || metric_str_ == METRIC_TAN_STR);
assert(metric_str_ == METRIC_HAM_STR || metric_str_ == METRIC_JAC_STR);
metric_type_ = (metric_str_ == METRIC_HAM_STR) ? knowhere::metric::HAMMING : knowhere::metric::JACCARD;
cfg_[knowhere::meta::METRIC_TYPE] = metric_type_;
cfg_[knowhere::meta::RADIUS] = *gt_radius_;
Expand Down
2 changes: 1 addition & 1 deletion benchmark/hdf5/benchmark_float_range_bitset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Benchmark_float_range_bitset : public Benchmark_knowhere, public ::testing
auto conf = cfg;
auto radius = conf[knowhere::meta::RADIUS].get<float>();

printf("\n[%0.3f s] %s | %s | radius=.3f\n", get_time_diff(), ann_test_name_.c_str(), index_type_.c_str(),
printf("\n[%0.3f s] %s | %s | radius=%.3f\n", get_time_diff(), ann_test_name_.c_str(), index_type_.c_str(),
radius);
printf("================================================================================\n");
for (auto per : PERCENTs_) {
Expand Down
1 change: 0 additions & 1 deletion benchmark/hdf5/benchmark_hdf5.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ static const char* METRIC_IP_STR = "angular";
static const char* METRIC_L2_STR = "euclidean";
static const char* METRIC_HAM_STR = "hamming";
static const char* METRIC_JAC_STR = "jaccard";
static const char* METRIC_TAN_STR = "tanimoto";

/************************************************************************************
* https://github.com/erikbern/ann-benchmarks
Expand Down
11 changes: 5 additions & 6 deletions benchmark/hdf5/benchmark_knowhere.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "knowhere/config.h"
#include "knowhere/factory.h"
#include "knowhere/index.h"
#include "knowhere/utils.h"

class Benchmark_knowhere : public Benchmark_hdf5 {
public:
Expand Down Expand Up @@ -73,12 +74,10 @@ class Benchmark_knowhere : public Benchmark_hdf5 {
binary_set.Append(name, data_ptr, data_size);
}

// IVFFLAT_NM should load raw data
knowhere::BinaryPtr bin = std::make_shared<knowhere::Binary>();
bin->data = std::shared_ptr<uint8_t[]>((uint8_t*)xb_, [&](uint8_t*) {});
bin->size = dim_ * nb_ * sizeof(float);
binary_set.Append("RAW_DATA", bin);

// assemble raw data for IVF_FLAT_NM
if (knowhere::IsIvfFlatNM(binary_set)) {
knowhere::AssembleWithRawData(binary_set, (uint8_t*)xb_, dim_ * nb_ * sizeof(float));
}
index.Deserialize(binary_set);
}

Expand Down
8 changes: 8 additions & 0 deletions include/knowhere/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@

#include <vector>

#include "knowhere/binaryset.h"
#include "knowhere/dataset.h"

namespace knowhere {

extern const float FloatAccuracy;
extern const char* IVF_FLAT_NM_MAGIC;

inline bool
IsMetricType(const std::string& str, const knowhere::MetricType& metric_type) {
Expand Down Expand Up @@ -65,4 +67,10 @@ round_down(const T value, const T align) {
return value / align * align;
}

bool
IsIvfFlatNM(const BinarySet& binset);

void
AssembleWithRawData(BinarySet& binset, const uint8_t* raw_data, const size_t raw_size);

} // namespace knowhere
62 changes: 62 additions & 0 deletions src/common/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

#include "faiss/impl/FaissException.h"
#include "faiss/index_io.h"
#include "io/FaissIO.h"
#include "knowhere/log.h"
#include "simd/hook.h"

namespace knowhere {

const float FloatAccuracy = 0.00001;
const char* IVF_FLAT_NM_MAGIC = "IVF_FLAT_NM";

float
NormalizeVec(float* x, int32_t d) {
Expand Down Expand Up @@ -66,4 +71,61 @@ CopyAndNormalizeFloatVec(const float* x, int32_t dim) {
return x_norm;
}

bool
IsIvfFlatNM(const BinarySet& binset) {
std::vector<std::string> names = {"IVF", // compatible with knowhere-1.x
knowhere::IndexEnum::INDEX_FAISS_IVFFLAT};
auto binary = binset.GetByNames(names);
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
return false;
}

bool is_nm = false;
// there are 2 possibilities for the input index binary:
// 1. IVF_FLAT_NM
// 2. native IVF_FLAT
try {
// try to parse as native format, if it's actually _NM format,
// faiss will raise a "read error" exception
MemoryIOReader reader;
reader.data_ = binary->data.get();
reader.total = binary->size;

faiss::read_index(&reader);
} catch (faiss::FaissException& e) {
is_nm = true;
}

return is_nm;
}

void
AssembleWithRawData(BinarySet& binset, const uint8_t* raw_data, const size_t raw_size) {
std::vector<std::string> names = {"IVF", // compatible with knowhere-1.x
knowhere::IndexEnum::INDEX_FAISS_IVFFLAT};
auto binary = binset.GetByNames(names);
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
return;
}

/* append raw data
*
* (index binary) (raw data) (magic)
* |--------------------------|-----------------------------|--------|
*/
size_t magic_len = strlen(IVF_FLAT_NM_MAGIC);
size_t new_size = binary->size + raw_size + magic_len;
auto new_data = new uint8_t[new_size];
std::shared_ptr<uint8_t[]> new_ptr(new_data);
memcpy(new_data, binary->data.get(), binary->size);
memcpy(new_data + binary->size, raw_data, raw_size);
binary->size += raw_size;
memcpy(new_data + binary->size, IVF_FLAT_NM_MAGIC, magic_len);
binary->size += magic_len;

binary->data = new_ptr;
}

} // namespace knowhere
118 changes: 14 additions & 104 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class IvfIndexNode : public IndexNode {
auto nb = index_->invlists->compute_ntotal();
auto nlist = index_->nlist;
auto code_size = index_->code_size;
return (nb * code_size + nb * sizeof(int64_t) + nlist * code_size);
return ((nb + nlist) * (code_size + sizeof(int64_t)));
}
if constexpr (std::is_same<T, faiss::IndexIVFFlatCC>::value) {
auto nb = index_->invlists->compute_ntotal();
Expand Down Expand Up @@ -199,10 +199,6 @@ class IvfIndexNode : public IndexNode {
private:
std::unique_ptr<T> index_;
std::shared_ptr<ThreadPool> search_pool_;

// temporary solution to fix IVF_FLAT cosine
mutable bool normalized_ = false;
mutable std::mutex normalize_mtx_;
};

} // namespace knowhere
Expand Down Expand Up @@ -253,7 +249,6 @@ IvfIndexNode<T>::Train(const DataSet& dataset, const Config& cfg) {
if (IsMetricType(base_cfg.metric_type.value(), knowhere::metric::COSINE)) {
if constexpr (!(std::is_same_v<faiss::IndexIVFFlatCC, T>)&&!(std::is_same_v<faiss::IndexScaNN, T>)) {
Normalize(dataset);
normalized_ = true;
}
}

Expand Down Expand Up @@ -352,14 +347,11 @@ IvfIndexNode<T>::Add(const DataSet& dataset, const Config& cfg) {
setter = std::make_unique<ThreadPool::ScopedOmpSetter>(base_cfg.num_build_thread.value());
}
try {
if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
index_->add_without_codes(rows, (const float*)data);
} else if constexpr (std::is_same<faiss::IndexBinaryIVF, T>::value) {
if constexpr (std::is_same<faiss::IndexBinaryIVF, T>::value) {
index_->add(rows, (const uint8_t*)data);
} else {
index_->add(rows, (const float*)data);
}

} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return Status::faiss_inner_error;
Expand Down Expand Up @@ -408,25 +400,6 @@ IvfIndexNode<T>::Search(const DataSet& dataset, const Config& cfg, const BitsetV
distances[i + offset] = static_cast<float>(i_distances[i + offset]);
}
}
} else if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
auto cur_query = (const float*)data + index * dim;
if (is_cosine) {
copied_query = CopyAndNormalizeFloatVec(cur_query, dim);
cur_query = copied_query.get();

// temporary solution to fix IVF_FLAT cosine
if (!normalized_) {
std::lock_guard<std::mutex> lock(normalize_mtx_);
if (!normalized_) {
faiss::IndexIVFFlat* ivf_index = static_cast<faiss::IndexIVFFlat*>(index_.get());
size_t nb = ivf_index->arranged_codes.size() / ivf_index->code_size;
NormalizeVecs((float*)(ivf_index->arranged_codes.data()), nb, dim);
normalized_ = true;
}
}
}
index_->search_without_codes_thread_safe(1, cur_query, k, distances + offset, ids + offset, nprobe,
0, bitset);
} else if constexpr (std::is_same<T, faiss::IndexScaNN>::value) {
auto cur_query = (const float*)data + index * dim;
const ScannConfig& scann_cfg = static_cast<const ScannConfig&>(cfg);
Expand Down Expand Up @@ -503,25 +476,6 @@ IvfIndexNode<T>::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi
if constexpr (std::is_same<T, faiss::IndexBinaryIVF>::value) {
auto cur_data = (const uint8_t*)xq + index * dim / 8;
index_->range_search_thread_safe(1, cur_data, radius, &res, index_->nlist, bitset);
} else if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
auto cur_query = (const float*)xq + index * dim;
if (is_cosine) {
copied_query = CopyAndNormalizeFloatVec(cur_query, dim);
cur_query = copied_query.get();

// temporary solution to fix IVF_FLAT cosine
if (!normalized_) {
std::lock_guard<std::mutex> lock(normalize_mtx_);
if (!normalized_) {
faiss::IndexIVFFlat* ivf_index = static_cast<faiss::IndexIVFFlat*>(index_.get());
size_t nb = ivf_index->arranged_codes.size() / ivf_index->code_size;
NormalizeVecs((float*)(ivf_index->arranged_codes.data()), nb, dim);
normalized_ = true;
}
}
}
index_->range_search_without_codes_thread_safe(1, cur_query, radius, &res, index_->nlist, 0,
bitset);
} else if constexpr (std::is_same<T, faiss::IndexScaNN>::value) {
auto cur_query = (const float*)xq + index * dim;
if (is_cosine) {
Expand Down Expand Up @@ -592,27 +546,7 @@ IvfIndexNode<T>::GetVectorByIds(const DataSet& dataset) const {
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return expected<DataSetPtr>::Err(Status::faiss_inner_error, e.what());
}
} else if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
auto dim = Dim();
auto rows = dataset.GetRows();
auto ids = dataset.GetIds();

float* data = nullptr;
try {
data = new float[dim * rows];
index_->make_direct_map(true);
for (int64_t i = 0; i < rows; i++) {
int64_t id = ids[i];
assert(id >= 0 && id < index_->ntotal);
index_->reconstruct_without_codes(id, data + i * dim);
}
return GenResultDataSet(rows, dim, data);
} catch (const std::exception& e) {
std::unique_ptr<float[]> auto_del(data);
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return expected<DataSetPtr>::Err(Status::faiss_inner_error, e.what());
}
} else if constexpr (std::is_same<T, faiss::IndexIVFFlatCC>::value) {
} else if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value || std::is_same<T, faiss::IndexIVFFlatCC>::value) {
auto dim = Dim();
auto rows = dataset.GetRows();
auto ids = dataset.GetIds();
Expand Down Expand Up @@ -702,8 +636,6 @@ IvfIndexNode<T>::Serialize(BinarySet& binset) const {
MemoryIOWriter writer;
if constexpr (std::is_same<T, faiss::IndexBinaryIVF>::value) {
faiss::write_index_binary(index_.get(), &writer);
} else if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
faiss::write_index_nm(index_.get(), &writer);
} else {
faiss::write_index(index_.get(), &writer);
}
Expand Down Expand Up @@ -732,7 +664,17 @@ IvfIndexNode<T>::Deserialize(const BinarySet& binset, const Config& config) {
reader.total = binary->size;
reader.data_ = binary->data.get();
try {
if constexpr (std::is_same<T, faiss::IndexBinaryIVF>::value) {
if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
// For IVF_FLAT_NM, AssembleWithRawData() should have been called,
// we can use IVF_FLAT_NM_MAGIC to tell it's IVF_FLAT_NM or not
size_t magic_len = strlen(IVF_FLAT_NM_MAGIC);
if (strncmp((const char*)reader.data_ + reader.total - magic_len, IVF_FLAT_NM_MAGIC, magic_len) == 0) {
index_.reset(static_cast<faiss::IndexIVFFlat*>(faiss::read_index_nm(&reader)));
index_->restore_codes(reader.data_ + reader.rp);
} else {
index_.reset(static_cast<faiss::IndexIVFFlat*>(faiss::read_index(&reader)));
}
} else if constexpr (std::is_same<T, faiss::IndexBinaryIVF>::value) {
index_.reset(static_cast<T*>(faiss::read_index_binary(&reader)));
} else {
index_.reset(static_cast<T*>(faiss::read_index(&reader)));
Expand Down Expand Up @@ -766,38 +708,6 @@ IvfIndexNode<T>::DeserializeFromFile(const std::string& filename, const Config&
return Status::success;
}

template <>
Status
IvfIndexNode<faiss::IndexIVFFlat>::Deserialize(const BinarySet& binset, const Config& config) {
std::vector<std::string> names = {"IVF", // compatible with knowhere-1.x
Type()};
auto binary = binset.GetByNames(names);
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
return Status::invalid_binary_set;
}

MemoryIOReader reader;
reader.total = binary->size;
reader.data_ = binary->data.get();
try {
index_.reset(static_cast<faiss::IndexIVFFlat*>(faiss::read_index_nm(&reader)));

// Construct arranged data from original data
auto binary = binset.GetByName("RAW_DATA");
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
return Status::invalid_binary_set;
}
size_t nb = binary->size / index_->invlists->code_size;
index_->arrange_codes(nb, (const float*)(binary->data.get()));
} catch (const std::exception& e) {
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return Status::faiss_inner_error;
}
return Status::success;
}

KNOWHERE_REGISTER_GLOBAL(IVFBIN, [](const Object& object) {
return Index<IvfIndexNode<faiss::IndexBinaryIVF>>::Create(object);
});
Expand Down
18 changes: 0 additions & 18 deletions tests/ut/test_feder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,22 +167,6 @@ TEST_CASE("Test Feder", "[feder]") {
return json;
};

auto load_raw_data = [](knowhere::Index<knowhere::IndexNode>& index, const knowhere::DataSet& dataset,
const knowhere::Json& conf) {
auto rows = dataset.GetRows();
auto dim = dataset.GetDim();
auto p_data = dataset.GetTensor();
knowhere::BinarySet bs;
auto res = index.Serialize(bs);
REQUIRE(res == knowhere::Status::success);
knowhere::BinaryPtr bptr = std::make_shared<knowhere::Binary>();
bptr->data = std::shared_ptr<uint8_t[]>((uint8_t*)p_data, [&](uint8_t*) {});
bptr->size = dim * rows * sizeof(float);
bs.Append("RAW_DATA", bptr);
res = index.Deserialize(bs);
REQUIRE(res == knowhere::Status::success);
};

const auto train_ds = GenDataSet(nb, dim, seed);
const auto query_ds = GenDataSet(nq, dim, seed);

Expand Down Expand Up @@ -222,8 +206,6 @@ TEST_CASE("Test Feder", "[feder]") {
auto res = idx.Build(*train_ds, json);
REQUIRE(res == knowhere::Status::success);

load_raw_data(idx, *train_ds, json);

auto res1 = idx.GetIndexMeta(json);
REQUIRE(res1.has_value());
CheckIvfFlatMeta(res1.value(), nb, json);
Expand Down
Loading

0 comments on commit 69903a7

Please sign in to comment.