From 91a9e08aa89b71e4f4beed9fab4f6e88c0639b7b Mon Sep 17 00:00:00 2001 From: "yusheng.ma" Date: Tue, 29 Oct 2024 20:08:31 +0800 Subject: [PATCH] update raft to 24.10 Signed-off-by: yusheng.ma --- cmake/libs/libraft.cmake | 2 +- cmake/libs/librapids.cmake | 2 +- .../raft/integration/raft_knowhere_index.cuh | 21 ++--------- src/common/raft/proto/raft_index.cuh | 35 ++++++++++--------- tests/ut/test_gpu_search.cc | 16 ++++----- tests/ut/test_index_check.cc | 2 +- 6 files changed, 33 insertions(+), 45 deletions(-) diff --git a/cmake/libs/libraft.cmake b/cmake/libs/libraft.cmake index dd4506f56..b2894f282 100644 --- a/cmake/libs/libraft.cmake +++ b/cmake/libs/libraft.cmake @@ -17,7 +17,7 @@ add_definitions(-DKNOWHERE_WITH_RAFT) add_definitions(-DRAFT_EXPLICIT_INSTANTIATE_ONLY) set(RAFT_VERSION "${RAPIDS_VERSION}") set(RAFT_FORK "milvus-io") -set(RAFT_PINNED_TAG "branch-24.04") +set(RAFT_PINNED_TAG "branch-24.10") rapids_find_package(CUDAToolkit REQUIRED BUILD_EXPORT_SET knowhere-exports INSTALL_EXPORT_SET knowhere-exports) diff --git a/cmake/libs/librapids.cmake b/cmake/libs/librapids.cmake index a8b4e6c8c..3d5ceec92 100644 --- a/cmake/libs/librapids.cmake +++ b/cmake/libs/librapids.cmake @@ -13,7 +13,7 @@ # License for the specific language governing permissions and limitations under # the License. -set(RAPIDS_VERSION 24.04) +set(RAPIDS_VERSION 24.10) if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) file( diff --git a/src/common/raft/integration/raft_knowhere_index.cuh b/src/common/raft/integration/raft_knowhere_index.cuh index 236c32bee..41c7480e5 100644 --- a/src/common/raft/integration/raft_knowhere_index.cuh +++ b/src/common/raft/integration/raft_knowhere_index.cuh @@ -404,28 +404,16 @@ struct raft_knowhere_index::impl { } auto const& res = get_device_resources_without_mempool(); auto host_data = raft::make_host_matrix_view(data, row_count, feature_count); - if constexpr (index_kind == raft_proto::raft_index_kind::ivf_flat) { + if (config.cache_dataset_on_device) { device_dataset_storage = raft::make_device_matrix(res, row_count, feature_count); auto device_data = device_dataset_storage->view(); raft::copy(res, device_data, host_data); index_ = raft_index_type::template build( res, index_params, raft::make_const_mdspan(device_data)); - if (!config.cache_dataset_on_device) { - device_dataset_storage = std::nullopt; - } } else { - if (config.cache_dataset_on_device) { - device_dataset_storage = - raft::make_device_matrix(res, row_count, feature_count); - auto device_data = device_dataset_storage->view(); - raft::copy(res, device_data, host_data); - index_ = raft_index_type::template build( - res, index_params, raft::make_const_mdspan(device_data)); - } else { - index_ = raft_index_type::template build( - res, index_params, raft::make_const_mdspan(host_data)); - } + index_ = raft_index_type::template build( + res, index_params, raft::make_const_mdspan(host_data)); } } @@ -447,7 +435,6 @@ struct raft_knowhere_index::impl { } else { if (index_) { auto const& res = get_device_resources_without_mempool(); - raft::resource::set_workspace_to_pool_resource(res); auto host_data = raft::make_host_matrix_view(data, row_count, feature_count); device_dataset_storage = raft::make_device_matrix(res, row_count, feature_count); @@ -488,11 +475,9 @@ struct raft_knowhere_index::impl { auto device_data_storage = raft::make_device_matrix(res, row_count, feature_count); raft::copy(res, device_data_storage.view(), host_data); - auto device_bitset = std::optional>{}; auto k_tmp = k; - if (bitset_data != nullptr && bitset_byte_size != 0) { device_bitset = raft::core::bitset(res, bitset_size); diff --git a/src/common/raft/proto/raft_index.cuh b/src/common/raft/proto/raft_index.cuh index c491c466e..efa631933 100644 --- a/src/common/raft/proto/raft_index.cuh +++ b/src/common/raft/proto/raft_index.cuh @@ -102,20 +102,22 @@ post_filter(raft::resources const& res, filter_lambda_t const& sample_filter, in // below, but I am not sure whether or not that would be a net benefit. This // deserves some benchmarking unless pre-filtering gets in before we revisit // this. - thrust::for_each(raft::resource::get_thrust_policy(res), - thrust::make_zip_iterator( - thrust::make_tuple(counter, mdspan_begin(index_mdspan), mdspan_begin(distance_mdspan))), - thrust::make_zip_iterator(thrust::make_tuple( - counter + index_mdspan.size(), mdspan_end(index_mdspan), mdspan_end(distance_mdspan))), - [=] __device__(auto& index_id_distance) { - auto index = thrust::get<0>(index_id_distance); - auto& id = thrust::get<1>(index_id_distance); - auto& distance = thrust::get<2>(index_id_distance); - if (!sample_filter(index / index_mdspan.extent(1), id)) { - id = std::numeric_limits>::max(); - distance = std::numeric_limits>::max(); - } - }); + thrust::for_each( + raft::resource::get_thrust_policy(res), + thrust::make_zip_iterator( + thrust::make_tuple(counter, mdspan_begin(index_mdspan), mdspan_begin(distance_mdspan))), + thrust::make_zip_iterator( + thrust::make_tuple(counter + index_mdspan.size(), mdspan_end(index_mdspan), mdspan_end(distance_mdspan))), + [=] __device__(const thrust::tuple& index_id_distance) { + auto index = thrust::get<0>(index_id_distance); + auto& id = thrust::get<1>(index_id_distance); + auto& distance = thrust::get<2>(index_id_distance); + if (!sample_filter(index / index_mdspan.extent(1), id)) { + id = std::numeric_limits>::max(); + distance = std::numeric_limits>::max(); + } + }); for (auto i = 0; i < index_mdspan.extent(0); ++i) { auto id_row_begin = mdspan_begin_row(index_mdspan, i); auto id_row_end = mdspan_end_row(index_mdspan, i); @@ -285,8 +287,9 @@ struct raft_index { } else if constexpr (vector_index_kind == raft_index_kind::ivf_pq) { return raft_index{raft::neighbors::ivf_pq::build( res, index_params, data.data_handle(), data.extent(0), data.extent(1))}; - } else { - RAFT_FAIL("IVF flat does not support construction from host data"); + } else if constexpr (vector_index_kind == raft_index_kind::ivf_flat) { + return raft_index{ + raft::neighbors::ivf_flat::build(res, index_params, data)}; } } else { if constexpr (vector_index_kind == raft_index_kind::brute_force) { diff --git a/tests/ut/test_gpu_search.cc b/tests/ut/test_gpu_search.cc index b7024f0b1..3650fec9b 100644 --- a/tests/ut/test_gpu_search.cc +++ b/tests/ut/test_gpu_search.cc @@ -59,8 +59,8 @@ TEST_CASE("Test All GPU Index", "[search]") { auto cagra_gen = [base_gen]() { knowhere::Json json = base_gen(); - json[knowhere::indexparam::INTERMEDIATE_GRAPH_DEGREE] = 128; - json[knowhere::indexparam::GRAPH_DEGREE] = 64; + json[knowhere::indexparam::INTERMEDIATE_GRAPH_DEGREE] = 64; + json[knowhere::indexparam::GRAPH_DEGREE] = 32; json[knowhere::indexparam::ITOPK_SIZE] = 128; return json; }; @@ -100,7 +100,7 @@ TEST_CASE("Test All GPU Index", "[search]") { auto res = idx.Build(train_ds, json); REQUIRE(res == knowhere::Status::success); REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) == - IndexStaticFaced::HasRawData(name, version, json)); + knowhere::IndexStaticFaced::HasRawData(name, version, json)); auto results = idx.Search(query_ds, json, nullptr); REQUIRE(results.has_value()); auto ids = results.value()->GetIds(); @@ -130,7 +130,7 @@ TEST_CASE("Test All GPU Index", "[search]") { auto res = idx.Build(train_ds, json); REQUIRE(res == knowhere::Status::success); REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) == - IndexStaticFaced::HasRawData(name, version, json)); + knowhere::IndexStaticFaced::HasRawData(name, version, json)); std::vector(size_t, size_t)>> gen_bitset_funcs = { GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet}; const auto bitset_percentages = {0.4f, 0.98f}; @@ -145,7 +145,7 @@ TEST_CASE("Test All GPU Index", "[search]") { if (percentage == 0.98f) { REQUIRE(recall > 0.4f); } else { - REQUIRE(recall > 0.8f); + REQUIRE(recall > 0.7f); } } } @@ -171,7 +171,7 @@ TEST_CASE("Test All GPU Index", "[search]") { auto res = idx.Build(train_ds, json); REQUIRE(res == knowhere::Status::success); REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) == - IndexStaticFaced::HasRawData(name, version, json)); + knowhere::IndexStaticFaced::HasRawData(name, version, json)); const auto topk_values = {// Tuple with [TopKValue, Threshold] make_tuple(5, 0.85f), make_tuple(25, 0.85f), make_tuple(100, 0.85f)}; @@ -210,7 +210,7 @@ TEST_CASE("Test All GPU Index", "[search]") { auto idx_ = knowhere::IndexFactory::Instance().Create(name, version).value(); idx_.Deserialize(bs); REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) == - IndexStaticFaced::HasRawData(name, version, json)); + knowhere::IndexStaticFaced::HasRawData(name, version, json)); auto results = idx_.Search(query_ds, json, nullptr); REQUIRE(results.has_value()); auto ids = results.value()->GetIds(); @@ -238,7 +238,7 @@ TEST_CASE("Test All GPU Index", "[search]") { auto res = idx.Build(train_ds, json); REQUIRE(res == knowhere::Status::success); REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) == - IndexStaticFaced::HasRawData(name, version, json)); + knowhere::IndexStaticFaced::HasRawData(name, version, json)); std::vector bitset_data(2); bitset_data[0] = 0b10100010; bitset_data[1] = 0b00100011; diff --git a/tests/ut/test_index_check.cc b/tests/ut/test_index_check.cc index fa4fedbac..fb658fcca 100644 --- a/tests/ut/test_index_check.cc +++ b/tests/ut/test_index_check.cc @@ -192,7 +192,7 @@ TEST_CASE("Test index has raw data", "[IndexHasRawData]") { CHECK_FALSE(knowhere::IndexStaticFaced::HasRawData(IndexEnum::INDEX_FAISS_HNSW_PRQ, ver, {})); // diskann -#ifndef KNOWHERE_WITH_CARDINAL +#ifdef KNOWHERE_WITH_DISKANN CHECK(knowhere::IndexStaticFaced::HasRawData(IndexEnum::INDEX_DISKANN, ver, knowhere::Json::parse(R"({"metric_type": "L2"})"))); CHECK_FALSE(knowhere::IndexStaticFaced::HasRawData(IndexEnum::INDEX_DISKANN, ver,