Skip to content

Commit

Permalink
update raft to 24.10 (#914)
Browse files Browse the repository at this point in the history
Signed-off-by: yusheng.ma <yusheng.ma@zilliz.com>
  • Loading branch information
Presburger authored Oct 30, 2024
1 parent aa01b18 commit 8e0150b
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 45 deletions.
2 changes: 1 addition & 1 deletion cmake/libs/libraft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ add_definitions(-DKNOWHERE_WITH_RAFT)
add_definitions(-DRAFT_EXPLICIT_INSTANTIATE_ONLY)
set(RAFT_VERSION "${RAPIDS_VERSION}")
set(RAFT_FORK "milvus-io")
set(RAFT_PINNED_TAG "branch-24.04")
set(RAFT_PINNED_TAG "branch-24.10")

rapids_find_package(CUDAToolkit REQUIRED BUILD_EXPORT_SET knowhere-exports
INSTALL_EXPORT_SET knowhere-exports)
Expand Down
2 changes: 1 addition & 1 deletion cmake/libs/librapids.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# License for the specific language governing permissions and limitations under
# the License.

set(RAPIDS_VERSION 24.04)
set(RAPIDS_VERSION 24.10)

if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake)
file(
Expand Down
21 changes: 3 additions & 18 deletions src/common/raft/integration/raft_knowhere_index.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -404,28 +404,16 @@ struct raft_knowhere_index<IndexKind>::impl {
}
auto const& res = get_device_resources_without_mempool();
auto host_data = raft::make_host_matrix_view(data, row_count, feature_count);
if constexpr (index_kind == raft_proto::raft_index_kind::ivf_flat) {
if (config.cache_dataset_on_device) {
device_dataset_storage =
raft::make_device_matrix<data_type, input_indexing_type>(res, row_count, feature_count);
auto device_data = device_dataset_storage->view();
raft::copy(res, device_data, host_data);
index_ = raft_index_type::template build<data_type, indexing_type, input_indexing_type>(
res, index_params, raft::make_const_mdspan(device_data));
if (!config.cache_dataset_on_device) {
device_dataset_storage = std::nullopt;
}
} else {
if (config.cache_dataset_on_device) {
device_dataset_storage =
raft::make_device_matrix<data_type, input_indexing_type>(res, row_count, feature_count);
auto device_data = device_dataset_storage->view();
raft::copy(res, device_data, host_data);
index_ = raft_index_type::template build<data_type, indexing_type, input_indexing_type>(
res, index_params, raft::make_const_mdspan(device_data));
} else {
index_ = raft_index_type::template build<data_type, indexing_type, input_indexing_type>(
res, index_params, raft::make_const_mdspan(host_data));
}
index_ = raft_index_type::template build<data_type, indexing_type, input_indexing_type>(
res, index_params, raft::make_const_mdspan(host_data));
}
}

Expand All @@ -447,7 +435,6 @@ struct raft_knowhere_index<IndexKind>::impl {
} else {
if (index_) {
auto const& res = get_device_resources_without_mempool();
raft::resource::set_workspace_to_pool_resource(res);
auto host_data = raft::make_host_matrix_view(data, row_count, feature_count);
device_dataset_storage =
raft::make_device_matrix<data_type, input_indexing_type>(res, row_count, feature_count);
Expand Down Expand Up @@ -488,11 +475,9 @@ struct raft_knowhere_index<IndexKind>::impl {
auto device_data_storage =
raft::make_device_matrix<data_type, input_indexing_type>(res, row_count, feature_count);
raft::copy(res, device_data_storage.view(), host_data);

auto device_bitset =
std::optional<raft::core::bitset<knowhere_bitset_data_type, knowhere_bitset_indexing_type>>{};
auto k_tmp = k;

if (bitset_data != nullptr && bitset_byte_size != 0) {
device_bitset =
raft::core::bitset<knowhere_bitset_data_type, knowhere_bitset_indexing_type>(res, bitset_size);
Expand Down
35 changes: 19 additions & 16 deletions src/common/raft/proto/raft_index.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,22 @@ post_filter(raft::resources const& res, filter_lambda_t const& sample_filter, in
// below, but I am not sure whether or not that would be a net benefit. This
// deserves some benchmarking unless pre-filtering gets in before we revisit
// this.
thrust::for_each(raft::resource::get_thrust_policy(res),
thrust::make_zip_iterator(
thrust::make_tuple(counter, mdspan_begin(index_mdspan), mdspan_begin(distance_mdspan))),
thrust::make_zip_iterator(thrust::make_tuple(
counter + index_mdspan.size(), mdspan_end(index_mdspan), mdspan_end(distance_mdspan))),
[=] __device__(auto& index_id_distance) {
auto index = thrust::get<0>(index_id_distance);
auto& id = thrust::get<1>(index_id_distance);
auto& distance = thrust::get<2>(index_id_distance);
if (!sample_filter(index / index_mdspan.extent(1), id)) {
id = std::numeric_limits<std::remove_reference_t<decltype(id)>>::max();
distance = std::numeric_limits<std::remove_reference_t<decltype(distance)>>::max();
}
});
thrust::for_each(
raft::resource::get_thrust_policy(res),
thrust::make_zip_iterator(
thrust::make_tuple(counter, mdspan_begin(index_mdspan), mdspan_begin(distance_mdspan))),
thrust::make_zip_iterator(
thrust::make_tuple(counter + index_mdspan.size(), mdspan_end(index_mdspan), mdspan_end(distance_mdspan))),
[=] __device__(const thrust::tuple<decltype(index_mdspan.extent(0)), typename index_mdspan_t::element_type&,
typename distance_mdspan_t::element_type&>& index_id_distance) {
auto index = thrust::get<0>(index_id_distance);
auto& id = thrust::get<1>(index_id_distance);
auto& distance = thrust::get<2>(index_id_distance);
if (!sample_filter(index / index_mdspan.extent(1), id)) {
id = std::numeric_limits<std::remove_reference_t<decltype(id)>>::max();
distance = std::numeric_limits<std::remove_reference_t<decltype(distance)>>::max();
}
});
for (auto i = 0; i < index_mdspan.extent(0); ++i) {
auto id_row_begin = mdspan_begin_row(index_mdspan, i);
auto id_row_end = mdspan_end_row(index_mdspan, i);
Expand Down Expand Up @@ -285,8 +287,9 @@ struct raft_index {
} else if constexpr (vector_index_kind == raft_index_kind::ivf_pq) {
return raft_index<underlying_index_type, raft_index_args...>{raft::neighbors::ivf_pq::build<T, IdxT>(
res, index_params, data.data_handle(), data.extent(0), data.extent(1))};
} else {
RAFT_FAIL("IVF flat does not support construction from host data");
} else if constexpr (vector_index_kind == raft_index_kind::ivf_flat) {
return raft_index<underlying_index_type, raft_index_args...>{
raft::neighbors::ivf_flat::build<T, IdxT>(res, index_params, data)};
}
} else {
if constexpr (vector_index_kind == raft_index_kind::brute_force) {
Expand Down
16 changes: 8 additions & 8 deletions tests/ut/test_gpu_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ TEST_CASE("Test All GPU Index", "[search]") {

auto cagra_gen = [base_gen]() {
knowhere::Json json = base_gen();
json[knowhere::indexparam::INTERMEDIATE_GRAPH_DEGREE] = 128;
json[knowhere::indexparam::GRAPH_DEGREE] = 64;
json[knowhere::indexparam::INTERMEDIATE_GRAPH_DEGREE] = 64;
json[knowhere::indexparam::GRAPH_DEGREE] = 32;
json[knowhere::indexparam::ITOPK_SIZE] = 128;
return json;
};
Expand Down Expand Up @@ -100,7 +100,7 @@ TEST_CASE("Test All GPU Index", "[search]") {
auto res = idx.Build(train_ds, json);
REQUIRE(res == knowhere::Status::success);
REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) ==
IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
auto results = idx.Search(query_ds, json, nullptr);
REQUIRE(results.has_value());
auto ids = results.value()->GetIds();
Expand Down Expand Up @@ -130,7 +130,7 @@ TEST_CASE("Test All GPU Index", "[search]") {
auto res = idx.Build(train_ds, json);
REQUIRE(res == knowhere::Status::success);
REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) ==
IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
std::vector<std::function<std::vector<uint8_t>(size_t, size_t)>> gen_bitset_funcs = {
GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet};
const auto bitset_percentages = {0.4f, 0.98f};
Expand All @@ -145,7 +145,7 @@ TEST_CASE("Test All GPU Index", "[search]") {
if (percentage == 0.98f) {
REQUIRE(recall > 0.4f);
} else {
REQUIRE(recall > 0.8f);
REQUIRE(recall > 0.7f);
}
}
}
Expand All @@ -171,7 +171,7 @@ TEST_CASE("Test All GPU Index", "[search]") {
auto res = idx.Build(train_ds, json);
REQUIRE(res == knowhere::Status::success);
REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) ==
IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
const auto topk_values = {// Tuple with [TopKValue, Threshold]
make_tuple(5, 0.85f), make_tuple(25, 0.85f), make_tuple(100, 0.85f)};

Expand Down Expand Up @@ -210,7 +210,7 @@ TEST_CASE("Test All GPU Index", "[search]") {
auto idx_ = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
idx_.Deserialize(bs);
REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) ==
IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
auto results = idx_.Search(query_ds, json, nullptr);
REQUIRE(results.has_value());
auto ids = results.value()->GetIds();
Expand Down Expand Up @@ -238,7 +238,7 @@ TEST_CASE("Test All GPU Index", "[search]") {
auto res = idx.Build(train_ds, json);
REQUIRE(res == knowhere::Status::success);
REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) ==
IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
std::vector<uint8_t> bitset_data(2);
bitset_data[0] = 0b10100010;
bitset_data[1] = 0b00100011;
Expand Down
2 changes: 1 addition & 1 deletion tests/ut/test_index_check.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ TEST_CASE("Test index has raw data", "[IndexHasRawData]") {
CHECK_FALSE(knowhere::IndexStaticFaced<fp32>::HasRawData(IndexEnum::INDEX_FAISS_HNSW_PRQ, ver, {}));

// diskann
#ifndef KNOWHERE_WITH_CARDINAL
#ifdef KNOWHERE_WITH_DISKANN
CHECK(knowhere::IndexStaticFaced<fp32>::HasRawData(IndexEnum::INDEX_DISKANN, ver,
knowhere::Json::parse(R"({"metric_type": "L2"})")));
CHECK_FALSE(knowhere::IndexStaticFaced<fp32>::HasRawData(IndexEnum::INDEX_DISKANN, ver,
Expand Down

0 comments on commit 8e0150b

Please sign in to comment.