Skip to content

Commit

Permalink
fix: correct the data size computed by GetVector for bf16/fp16/binary vectors (#33377)
Browse files (browse the repository at this point in the history)
related #22837

Signed-off-by: chasingegg <chao.gao@zilliz.com>
  • Loading branch information
chasingegg authored Jun 5, 2024
1 parent 597f4c5 commit 545d472
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 21 deletions.
1 change: 1 addition & 0 deletions internal/core/src/common/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ using distance_t = float;

using float16 = knowhere::fp16;
using bfloat16 = knowhere::bf16;
using bin1 = knowhere::bin1;

enum class DataType {
NONE = 0,
Expand Down
4 changes: 2 additions & 2 deletions internal/core/src/index/VectorDiskIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,10 +449,10 @@ VectorDiskAnnIndex<T>::GetVector(const DatasetPtr dataset) const {
auto row_num = res.value()->GetRows();
auto dim = res.value()->GetDim();
int64_t data_size;
if (is_in_bin_list(index_type)) {
if constexpr (std::is_same_v<T, bin1>) {
data_size = dim / 8 * row_num;
} else {
data_size = dim * row_num * sizeof(float);
data_size = dim * row_num * sizeof(T);
}
std::vector<uint8_t> raw_data;
raw_data.resize(data_size);
Expand Down
6 changes: 3 additions & 3 deletions internal/core/src/index/VectorMemIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -669,10 +669,10 @@ VectorMemIndex<T>::GetVector(const DatasetPtr dataset) const {
auto row_num = res.value()->GetRows();
auto dim = res.value()->GetDim();
int64_t data_size;
if (is_in_bin_list(index_type)) {
if constexpr (std::is_same_v<T, bin1>) {
data_size = dim / 8 * row_num;
} else {
data_size = dim * row_num * sizeof(float);
data_size = dim * row_num * sizeof(T);
}
std::vector<uint8_t> raw_data;
raw_data.resize(data_size);
Expand Down Expand Up @@ -954,7 +954,7 @@ VectorMemIndex<T>::LoadFromFileV2(const Config& config) {
LOG_INFO("load vector index done");
}
template class VectorMemIndex<float>;
template class VectorMemIndex<uint8_t>;
template class VectorMemIndex<bin1>;
template class VectorMemIndex<float16>;
template class VectorMemIndex<bfloat16>;

Expand Down
32 changes: 16 additions & 16 deletions internal/core/unittest/test_float16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,14 @@ TEST(Float16, GetVector) {

auto vector = result.get()->mutable_vectors()->float16_vector();
EXPECT_TRUE(vector.size() == num_inserted * dim * sizeof(float16));
// EXPECT_TRUE(vector.size() == num_inserted * dim);
// for (size_t i = 0; i < num_inserted; ++i) {
// auto id = ids_ds->GetIds()[i];
// for (size_t j = 0; j < 128; ++j) {
// EXPECT_TRUE(vector[i * dim + j] ==
// fakevec[(id % per_batch) * dim + j]);
// }
// }
for (size_t i = 0; i < num_inserted; ++i) {
auto id = ids_ds->GetIds()[i];
for (size_t j = 0; j < 128; ++j) {
EXPECT_TRUE(
reinterpret_cast<float16*>(vector.data())[i * dim + j] ==
fakevec[(id % per_batch) * dim + j]);
}
}
}
}

Expand Down Expand Up @@ -453,14 +453,14 @@ TEST(BFloat16, GetVector) {

auto vector = result.get()->mutable_vectors()->bfloat16_vector();
EXPECT_TRUE(vector.size() == num_inserted * dim * sizeof(bfloat16));
// EXPECT_TRUE(vector.size() == num_inserted * dim);
// for (size_t i = 0; i < num_inserted; ++i) {
// auto id = ids_ds->GetIds()[i];
// for (size_t j = 0; j < 128; ++j) {
// EXPECT_TRUE(vector[i * dim + j] ==
// fakevec[(id % per_batch) * dim + j]);
// }
// }
for (size_t i = 0; i < num_inserted; ++i) {
auto id = ids_ds->GetIds()[i];
for (size_t j = 0; j < 128; ++j) {
EXPECT_TRUE(
reinterpret_cast<bfloat16*>(vector.data())[i * dim + j] ==
fakevec[(id % per_batch) * dim + j]);
}
}
}
}

Expand Down

0 comments on commit 545d472

Please sign in to comment.