Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Add binary unittest to improve code coverage
Browse files Browse the repository at this point in the history
Signed-off-by: Yudong Cai <yudong.cai@zilliz.com>
  • Loading branch information
cydrain committed Jun 29, 2023
1 parent 4916757 commit f68b445
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 45 deletions.
29 changes: 17 additions & 12 deletions tests/ut/test_get_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") {
using Catch::Approx;

int64_t nb = 10000;
int64_t dim = 128;
int64_t seed = 42;
const int64_t nb = 1000;
const int64_t dim = 128;

const auto metric_type = knowhere::metric::HAMMING;

auto base_bin_gen = [&]() {
knowhere::Json json;
json[knowhere::meta::DIM] = dim;
json[knowhere::meta::METRIC_TYPE] = knowhere::metric::HAMMING;
json[knowhere::meta::METRIC_TYPE] = metric_type;
json[knowhere::meta::TOPK] = 1;
return json;
};
Expand All @@ -48,10 +49,13 @@ TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") {
make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, bin_ivfflat_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create(name);
if (!idx.HasRawData(metric_type)) {
return;
}
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
auto train_ds = GenBinDataSet(nb, dim, seed);
auto train_ds = GenBinDataSet(nb, dim);
auto ids_ds = GenIdsDataSet(nb, dim);
REQUIRE(idx.Type() == name);
auto res = idx.Build(*train_ds, json);
Expand Down Expand Up @@ -82,9 +86,8 @@ TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") {
TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") {
using Catch::Approx;

int64_t nb = 10000;
int64_t dim = 128;
int64_t seed = 42;
const int64_t nb = 1000;
const int64_t dim = 128;

auto metric = GENERATE(as<std::string>{}, knowhere::metric::L2, knowhere::metric::COSINE);

Expand Down Expand Up @@ -141,13 +144,18 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") {
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IDMAP, flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, base_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create(name);
if (!idx.HasRawData(metric)) {
return;
}
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
auto train_ds = GenDataSet(nb, dim, seed);
auto train_ds = GenDataSet(nb, dim);
auto train_ds_copy = CopyDataSet(train_ds, nb);
auto ids_ds = GenIdsDataSet(nb, dim);
REQUIRE(idx.Type() == name);
Expand All @@ -161,9 +169,6 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") {
if (name == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) {
load_raw_data(idx_new, *train_ds, json);
}
if (!idx_new.HasRawData(metric)) {
return;
}
auto results = idx_new.GetVectorByIds(*ids_ds);
REQUIRE(results.has_value());
auto xb = (float*)train_ds_copy->GetTensor();
Expand Down
2 changes: 1 addition & 1 deletion tests/ut/test_knowhere_init.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ TEST_CASE("Knowhere global config", "[init]") {
knowhere::KnowhereConfig::SetEarlyStopThreshold(early_stop_threshold);
REQUIRE(knowhere::KnowhereConfig::GetEarlyStopThreshold() == early_stop_threshold);

knowhere::KnowhereConfig::SetClusteringType(knowhere::KnowhereConfig::ClusteringType::K_MEANS);
knowhere::KnowhereConfig::SetClusteringType(knowhere::KnowhereConfig::ClusteringType::K_MEANS_PLUS_PLUS);
knowhere::KnowhereConfig::SetClusteringType(knowhere::KnowhereConfig::ClusteringType::K_MEANS);

#ifdef KNOWHERE_WITH_DISKANN
knowhere::KnowhereConfig::SetAioContextPool(128, 2048);
Expand Down
165 changes: 133 additions & 32 deletions tests/ut/test_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,26 @@
#include "utils.h"

namespace {
constexpr float kKnnRecallThreshold = 0.75f;
constexpr float kKnnRecallThreshold = 0.6f;
constexpr float kBruteForceRecallThreshold = 0.99f;
constexpr size_t kTopk = 1;
} // namespace

TEST_CASE("Test All Mem Index Search", "[search]") {
TEST_CASE("Test Mem Index With Float Vector", "[float vector]") {
using Catch::Approx;

int64_t nb = 10000, nq = 100;
int64_t dim = 128;
int64_t seed = 42;
const int64_t nb = 1000, nq = 10;
const int64_t dim = 128;
const int64_t topk = 5;

auto metric = GENERATE(as<std::string>{}, knowhere::metric::L2, knowhere::metric::COSINE);

auto base_gen = [&]() {
knowhere::Json json;
json[knowhere::meta::DIM] = dim;
json[knowhere::meta::METRIC_TYPE] = metric;
json[knowhere::meta::TOPK] = kTopk;
json[knowhere::meta::TOPK] = topk;
json[knowhere::meta::RADIUS] = knowhere::IsMetricType(metric, knowhere::metric::L2) ? 10.0 : 0.99;
json[knowhere::meta::RANGE_FILTER] = knowhere::IsMetricType(metric, knowhere::metric::L2) ? 0.0 : 1.01;
return json;
};

Expand Down Expand Up @@ -83,24 +83,24 @@ TEST_CASE("Test All Mem Index Search", "[search]") {
auto dim = dataset.GetDim();
auto p_data = dataset.GetTensor();
knowhere::BinarySet bs;
auto res = index.Serialize(bs);
REQUIRE(res == knowhere::Status::success);
REQUIRE(index.Serialize(bs) == knowhere::Status::success);
knowhere::BinaryPtr bptr = std::make_shared<knowhere::Binary>();
bptr->data = std::shared_ptr<uint8_t[]>((uint8_t*)p_data, [&](uint8_t*) {});
bptr->size = dim * rows * sizeof(float);
bs.Append("RAW_DATA", bptr);
res = index.Deserialize(bs);
REQUIRE(res == knowhere::Status::success);
REQUIRE(index.Deserialize(bs) == knowhere::Status::success);
};

const auto train_ds = GenDataSet(nb, dim, seed);
const auto query_ds = GenDataSet(nq, dim, seed);
const auto train_ds = GenDataSet(nb, dim);
const auto query_ds = GenDataSet(nq, dim);

const knowhere::Json conf = {
{knowhere::meta::METRIC_TYPE, metric},
{knowhere::meta::TOPK, kTopk},
{knowhere::meta::TOPK, topk},
};
auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr);
SECTION("Test Cpu Index Search") {

SECTION("Test Search") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IDMAP, flat_gen),
Expand All @@ -115,47 +115,51 @@ TEST_CASE("Test All Mem Index Search", "[search]") {
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
auto res = idx.Build(*train_ds, json);
REQUIRE(res == knowhere::Status::success);
REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success);
REQUIRE(idx.Size() > 0);
REQUIRE(idx.Count() == nb);
if (name == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) {
load_raw_data(idx, *train_ds, json);
}
auto results = idx.Search(*query_ds, json, nullptr);
REQUIRE(results.has_value());
float recall = GetKNNRecall(*gt.value(), *results.value());
REQUIRE(recall > kKnnRecallThreshold);
if (name != "IVF_PQ") {
REQUIRE(recall > kKnnRecallThreshold);
}
}

SECTION("Test Cpu Index Range Search") {
SECTION("Test Range Search") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IDMAP, flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivfsq_gen),
// make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create(name);
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
auto res = idx.Build(*train_ds, json);
REQUIRE(res == knowhere::Status::success);
REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success);
if (name == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) {
load_raw_data(idx, *train_ds, json);
}
auto results = idx.RangeSearch(*query_ds, json, nullptr);
REQUIRE(results.has_value());
auto ids = results.value()->GetIds();
auto lims = results.value()->GetLims();
for (int i = 0; i < nq; ++i) {
CHECK(ids[lims[i]] == i);
if (name != "IVF_PQ") {
for (int i = 0; i < nq; ++i) {
CHECK(ids[lims[i]] == i);
}
}
}

SECTION("Test Cpu Index Search with Bitset") {
SECTION("Test Search with Bitset") {
using std::make_tuple;
auto [name, gen, threshold] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>, float>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold),
Expand All @@ -165,8 +169,7 @@ TEST_CASE("Test All Mem Index Search", "[search]") {
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
auto res = idx.Build(*train_ds, json);
REQUIRE(res == knowhere::Status::success);
REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success);

std::vector<std::function<std::vector<uint8_t>(size_t, size_t)>> gen_bitset_funcs = {
GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet};
Expand All @@ -187,7 +190,7 @@ TEST_CASE("Test All Mem Index Search", "[search]") {
}
}

SECTION("Test Cpu Index Serialize/Deserialize") {
SECTION("Test Serialize/Deserialize") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IDMAP, flat_gen),
Expand All @@ -203,8 +206,7 @@ TEST_CASE("Test All Mem Index Search", "[search]") {
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
auto res = idx.Build(*train_ds, json);
REQUIRE(res == knowhere::Status::success);
REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success);
knowhere::BinarySet bs;
idx.Serialize(bs);

Expand All @@ -217,7 +219,7 @@ TEST_CASE("Test All Mem Index Search", "[search]") {
REQUIRE(results.has_value());
}

SECTION("Test build IVFPQ with invalid params") {
SECTION("Test IVFPQ with invalid params") {
auto idx = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFPQ);
uint32_t nb = 1000;
uint32_t dim = 128;
Expand All @@ -230,8 +232,107 @@ TEST_CASE("Test All Mem Index Search", "[search]") {
json[knowhere::indexparam::NBITS] = 8;
return json;
};
auto train_ds = GenDataSet(nb, dim, seed);
auto train_ds = GenDataSet(nb, dim);
auto res = idx.Build(*train_ds, ivf_pq_gen());
REQUIRE(res == knowhere::Status::faiss_inner_error);
}
}

TEST_CASE("Test Mem Index With Binary Vector", "[binary vector]") {
using Catch::Approx;

const int64_t nb = 1000, nq = 10;
const int64_t dim = 1024;
const int64_t topk = 5;

auto metric = GENERATE(as<std::string>{}, knowhere::metric::HAMMING);

auto base_gen = [&]() {
knowhere::Json json;
json[knowhere::meta::DIM] = dim;
json[knowhere::meta::METRIC_TYPE] = metric;
json[knowhere::meta::TOPK] = topk;
json[knowhere::meta::RADIUS] = 10.0;
json[knowhere::meta::RANGE_FILTER] = 0.0;
return json;
};

auto flat_gen = base_gen;
auto ivfflat_gen = [&base_gen]() {
knowhere::Json json = base_gen();
json[knowhere::indexparam::NLIST] = 16;
json[knowhere::indexparam::NPROBE] = 8;
return json;
};

const auto train_ds = GenBinDataSet(nb, dim);
const auto query_ds = GenBinDataSet(nq, dim);
const knowhere::Json conf = {
{knowhere::meta::METRIC_TYPE, metric},
{knowhere::meta::TOPK, topk},
};

auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr);
SECTION("Test Search") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create(name);
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success);
REQUIRE(idx.Size() > 0);
REQUIRE(idx.Count() == nb);
auto results = idx.Search(*query_ds, json, nullptr);
REQUIRE(results.has_value());
float recall = GetKNNRecall(*gt.value(), *results.value());
REQUIRE(recall > kKnnRecallThreshold);
}

SECTION("Test Range Search") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create(name);
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success);
auto results = idx.RangeSearch(*query_ds, json, nullptr);
REQUIRE(results.has_value());
auto ids = results.value()->GetIds();
auto lims = results.value()->GetLims();
for (int i = 0; i < nq; ++i) {
CHECK(ids[lims[i]] == i);
}
}

SECTION("Test Serialize/Deserialize") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen),
}));

auto idx = knowhere::IndexFactory::Instance().Create(name);
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success);
knowhere::BinarySet bs;
idx.Serialize(bs);

auto idx_ = knowhere::IndexFactory::Instance().Create(name);
idx_.Deserialize(bs);
auto results = idx_.Search(*query_ds, json, nullptr);
REQUIRE(results.has_value());
}
}

0 comments on commit f68b445

Please sign in to comment.