Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Add more unittest to improve coverage
Browse files Browse the repository at this point in the history
Signed-off-by: Yudong Cai <yudong.cai@zilliz.com>
  • Loading branch information
cydrain committed Jun 26, 2023
1 parent 2364721 commit 211b909
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 48 deletions.
9 changes: 2 additions & 7 deletions src/simd/hook.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ fvec_hook(std::string& simd_type) {
fvec_madd_and_argmin = fvec_madd_and_argmin_sse;

simd_type = "AVX512";
return;
}
if (use_avx2 && cpu_support_avx2()) {
} else if (use_avx2 && cpu_support_avx2()) {
fvec_inner_product = fvec_inner_product_avx;
fvec_L2sqr = fvec_L2sqr_avx;
fvec_L1 = fvec_L1_avx;
Expand All @@ -100,9 +98,7 @@ fvec_hook(std::string& simd_type) {
fvec_madd_and_argmin = fvec_madd_and_argmin_sse;

simd_type = "AVX2";
return;
}
if (use_sse4_2 && cpu_support_sse4_2()) {
} else if (use_sse4_2 && cpu_support_sse4_2()) {
fvec_inner_product = fvec_inner_product_sse;
fvec_L2sqr = fvec_L2sqr_sse;
fvec_L1 = fvec_L1_sse;
Expand All @@ -116,7 +112,6 @@ fvec_hook(std::string& simd_type) {

simd_type = "SSE4_2";
}

#endif
}

Expand Down
38 changes: 0 additions & 38 deletions tests/ut/test_heap.cc

This file was deleted.

42 changes: 40 additions & 2 deletions tests/ut/test_knowhere_init.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,48 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.

#include <string>
#include <unordered_set>
#include <vector>

#include "catch2/catch_test_macros.hpp"
#include "knowhere/comp/knowhere_config.h"

TEST_CASE("Knowhere global config", "[init]") {
knowhere::KnowhereConfig::SetBlasThreshold(16384);
knowhere::KnowhereConfig::SetEarlyStopThreshold(0);
knowhere::KnowhereConfig::ShowVersion();

int64_t blas_threshold = 16384;
knowhere::KnowhereConfig::SetBlasThreshold(blas_threshold);
REQUIRE(knowhere::KnowhereConfig::GetBlasThreshold() == blas_threshold);

int64_t early_stop_threshold = 0;
knowhere::KnowhereConfig::SetEarlyStopThreshold(early_stop_threshold);
REQUIRE(knowhere::KnowhereConfig::GetEarlyStopThreshold() == early_stop_threshold);

knowhere::KnowhereConfig::SetClusteringType(knowhere::KnowhereConfig::ClusteringType::K_MEANS);
knowhere::KnowhereConfig::SetClusteringType(knowhere::KnowhereConfig::ClusteringType::K_MEANS_PLUS_PLUS);

#ifdef KNOWHERE_WITH_DISKANN
knowhere::KnowhereConfig::SetAioContextPool(1024, 16384);
#endif

#ifdef KNOWHERE_WITH_RAFT
knowhere::KnowhereConfig::SetRaftMemPool(1LL << 30, 2LL << 30);
#endif
}

TEST_CASE("Knowhere SIMD config", "[simd]") {
std::vector<std::string> v = {"AVX512", "AVX2", "SSE4_2", "REF"};
std::unordered_set<std::string> s(v.begin(), v.end());

auto res = knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX512);
REQUIRE(s.find(res) != s.end());
res = knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2);
REQUIRE(s.find(res) != s.end());
res = knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::SSE4_2);
REQUIRE(s.find(res) != s.end());
res = knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::GENERIC);
REQUIRE(s.find(res) != s.end());
res = knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AUTO);
REQUIRE(s.find(res) != s.end());
}
2 changes: 1 addition & 1 deletion tests/ut/test_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "utils.h"

namespace {
constexpr float kKnnRecallThreshold = 0.8f;
constexpr float kKnnRecallThreshold = 0.75f;
constexpr float kBruteForceRecallThreshold = 0.99f;
constexpr size_t kTopk = 1;
} // namespace
Expand Down
32 changes: 32 additions & 0 deletions tests/ut/test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "catch2/catch_approx.hpp"
#include "catch2/catch_test_macros.hpp"
#include "knowhere/comp/time_recorder.h"
#include "knowhere/heap.h"
#include "knowhere/utils.h"
#include "utils.h"

Expand Down Expand Up @@ -76,3 +77,34 @@ TEST_CASE("Test Bitset Generation", "[utils]") {
}
}
}

namespace {
constexpr size_t kHeapSize = 10;
constexpr size_t kElementCount = 10000;
} // namespace

TEST_CASE("ResultMaxHeap") {
knowhere::ResultMaxHeap<float, size_t> heap(kHeapSize);
auto pairs = GenerateRandomDistanceIdPair(kElementCount);
for (const auto& [dist, id] : pairs) {
heap.Push(dist, id);
}
REQUIRE(heap.Size() == kHeapSize);
std::sort(pairs.begin(), pairs.end());
for (int i = kHeapSize - 1; i >= 0; --i) {
auto op = heap.Pop();
REQUIRE(op.has_value());
REQUIRE(op.value().second == pairs[i].second);
}
REQUIRE(heap.Size() == 0);
}

TEST_CASE("Test Time Recorder") {
knowhere::TimeRecorder tr("test", 2);
int64_t sum = 0;
for (int i = 0; i < 10000; i++) {
sum += i * i;
}
auto span = tr.ElapseFromBegin("done");
REQUIRE(span > 0);
}

0 comments on commit 211b909

Please sign in to comment.