From dcf7832ac8ad02eed146861295073da686c70de0 Mon Sep 17 00:00:00 2001 From: chasingegg Date: Fri, 24 May 2024 18:58:15 +0800 Subject: [PATCH] fix: correct get vector data size for bf16/fp16/binary vector Signed-off-by: chasingegg --- internal/core/src/common/Types.h | 1 + internal/core/src/index/VectorDiskIndex.cpp | 4 +-- internal/core/src/index/VectorMemIndex.cpp | 6 ++-- internal/core/unittest/test_float16.cpp | 32 ++++++++++----------- 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index e22f1e230e40d..4d145778287c5 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -57,6 +57,7 @@ using distance_t = float; using float16 = knowhere::fp16; using bfloat16 = knowhere::bf16; +using bin1 = knowhere::bin1; enum class DataType { NONE = 0, diff --git a/internal/core/src/index/VectorDiskIndex.cpp b/internal/core/src/index/VectorDiskIndex.cpp index 5e3f23dd87c06..73811b5077c0e 100644 --- a/internal/core/src/index/VectorDiskIndex.cpp +++ b/internal/core/src/index/VectorDiskIndex.cpp @@ -449,10 +449,10 @@ VectorDiskAnnIndex::GetVector(const DatasetPtr dataset) const { auto row_num = res.value()->GetRows(); auto dim = res.value()->GetDim(); int64_t data_size; - if (is_in_bin_list(index_type)) { + if constexpr (std::is_same_v) { data_size = dim / 8 * row_num; } else { - data_size = dim * row_num * sizeof(float); + data_size = dim * row_num * sizeof(T); } std::vector raw_data; raw_data.resize(data_size); diff --git a/internal/core/src/index/VectorMemIndex.cpp b/internal/core/src/index/VectorMemIndex.cpp index fabbface68fc9..580c568e10bde 100644 --- a/internal/core/src/index/VectorMemIndex.cpp +++ b/internal/core/src/index/VectorMemIndex.cpp @@ -669,10 +669,10 @@ VectorMemIndex::GetVector(const DatasetPtr dataset) const { auto row_num = res.value()->GetRows(); auto dim = res.value()->GetDim(); int64_t data_size; - if (is_in_bin_list(index_type)) { + if constexpr (std::is_same_v) { data_size = dim / 8 * row_num; } else { - data_size = dim * row_num * sizeof(float); + data_size = dim * row_num * sizeof(T); } std::vector raw_data; raw_data.resize(data_size); @@ -954,7 +954,7 @@ VectorMemIndex::LoadFromFileV2(const Config& config) { LOG_INFO("load vector index done"); } template class VectorMemIndex; -template class VectorMemIndex; +template class VectorMemIndex; template class VectorMemIndex; template class VectorMemIndex; diff --git a/internal/core/unittest/test_float16.cpp b/internal/core/unittest/test_float16.cpp index 04ca348d45d80..bf172a5d47bcd 100644 --- a/internal/core/unittest/test_float16.cpp +++ b/internal/core/unittest/test_float16.cpp @@ -196,14 +196,14 @@ TEST(Float16, GetVector) { auto vector = result.get()->mutable_vectors()->float16_vector(); EXPECT_TRUE(vector.size() == num_inserted * dim * sizeof(float16)); - // EXPECT_TRUE(vector.size() == num_inserted * dim); - // for (size_t i = 0; i < num_inserted; ++i) { - // auto id = ids_ds->GetIds()[i]; - // for (size_t j = 0; j < 128; ++j) { - // EXPECT_TRUE(vector[i * dim + j] == - // fakevec[(id % per_batch) * dim + j]); - // } - // } + for (size_t i = 0; i < num_inserted; ++i) { + auto id = ids_ds->GetIds()[i]; + for (size_t j = 0; j < 128; ++j) { + EXPECT_TRUE( + reinterpret_cast(vector.data())[i * dim + j] == + fakevec[(id % per_batch) * dim + j]); + } + } } } @@ -453,14 +453,14 @@ TEST(BFloat16, GetVector) { auto vector = result.get()->mutable_vectors()->bfloat16_vector(); EXPECT_TRUE(vector.size() == num_inserted * dim * sizeof(bfloat16)); - // EXPECT_TRUE(vector.size() == num_inserted * dim); - // for (size_t i = 0; i < num_inserted; ++i) { - // auto id = ids_ds->GetIds()[i]; - // for (size_t j = 0; j < 128; ++j) { - // EXPECT_TRUE(vector[i * dim + j] == - // fakevec[(id % per_batch) * dim + j]); - // } - // } + for (size_t i = 0; i < num_inserted; ++i) { + auto id = ids_ds->GetIds()[i]; + for (size_t j = 0; j < 128; ++j) { + EXPECT_TRUE( + reinterpret_cast(vector.data())[i * dim + j] == + fakevec[(id % per_batch) * dim + j]); + } + } } }