diff --git a/faiss/utils/Heap.cpp b/faiss/utils/Heap.cpp index 63981ba6dc..16ee23f6c8 100644 --- a/faiss/utils/Heap.cpp +++ b/faiss/utils/Heap.cpp @@ -93,7 +93,7 @@ void HeapArray::addn_query_subset_with_ids( } #pragma omp parallel for if (nsubset * nj > 100000) for (int64_t si = 0; si < nsubset; si++) { - T i = subset[si]; + TI i = subset[si]; T* __restrict simi = get_val(i); TI* __restrict idxi = get_ids(i); const T* ip_line = vin + si * nj; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index dfe8c13394..8167c7dfc7 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -26,6 +26,7 @@ set(FAISS_TEST_SRC test_approx_topk.cpp test_RCQ_cropping.cpp test_distances_simd.cpp + test_heap.cpp ) add_executable(faiss_test ${FAISS_TEST_SRC}) diff --git a/tests/test_heap.cpp b/tests/test_heap.cpp new file mode 100644 index 0000000000..9481003db6 --- /dev/null +++ b/tests/test_heap.cpp @@ -0,0 +1,53 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include + +using namespace faiss; + +TEST(Heap, addn_with_ids) { + size_t n = 1000; + size_t k = 1; + std::vector heap_labels(n, -1); + std::vector heap_distances(n, 0); + float_minheap_array_t heaps = { + n, k, heap_labels.data(), heap_distances.data()}; + heaps.heapify(); + std::vector labels(n, 1); + std::vector distances(n, 0.0f); + std::vector subset(n); + std::iota(subset.begin(), subset.end(), 0); + heaps.addn_with_ids(1, distances.data(), labels.data(), 1); + heaps.reorder(); + EXPECT_TRUE( + std::all_of(heap_labels.begin(), heap_labels.end(), [](int64_t i) { + return i == 1; + })); +} + +TEST(Heap, addn_query_subset_with_ids) { + size_t n = 20000000; // more than 2^24 + size_t k = 1; + std::vector heap_labels(n, -1); + std::vector heap_distances(n, 0); + float_minheap_array_t heaps = { + n, k, heap_labels.data(), heap_distances.data()}; + heaps.heapify(); + std::vector labels(n, 1); + std::vector distances(n, 0.0f); + std::vector subset(n); + std::iota(subset.begin(), subset.end(), 0); + heaps.addn_query_subset_with_ids( + n, subset.data(), 1, distances.data(), labels.data(), 1); + heaps.reorder(); + EXPECT_TRUE( + std::all_of(heap_labels.begin(), heap_labels.end(), [](int64_t i) { + return i == 1; + })); +}