Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update raft to 24.10 #914

Merged
Merged 1 commit on Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/libs/libraft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ add_definitions(-DKNOWHERE_WITH_RAFT)
add_definitions(-DRAFT_EXPLICIT_INSTANTIATE_ONLY)
set(RAFT_VERSION "${RAPIDS_VERSION}")
set(RAFT_FORK "milvus-io")
set(RAFT_PINNED_TAG "branch-24.04")
set(RAFT_PINNED_TAG "branch-24.10")

rapids_find_package(CUDAToolkit REQUIRED BUILD_EXPORT_SET knowhere-exports
INSTALL_EXPORT_SET knowhere-exports)
Expand Down
2 changes: 1 addition & 1 deletion cmake/libs/librapids.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# License for the specific language governing permissions and limitations under
# the License.

set(RAPIDS_VERSION 24.04)
set(RAPIDS_VERSION 24.10)

if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake)
file(
Expand Down
21 changes: 3 additions & 18 deletions src/common/raft/integration/raft_knowhere_index.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -404,28 +404,16 @@ struct raft_knowhere_index<IndexKind>::impl {
}
auto const& res = get_device_resources_without_mempool();
auto host_data = raft::make_host_matrix_view(data, row_count, feature_count);
if constexpr (index_kind == raft_proto::raft_index_kind::ivf_flat) {
if (config.cache_dataset_on_device) {
device_dataset_storage =
raft::make_device_matrix<data_type, input_indexing_type>(res, row_count, feature_count);
auto device_data = device_dataset_storage->view();
raft::copy(res, device_data, host_data);
index_ = raft_index_type::template build<data_type, indexing_type, input_indexing_type>(
res, index_params, raft::make_const_mdspan(device_data));
if (!config.cache_dataset_on_device) {
device_dataset_storage = std::nullopt;
}
} else {
if (config.cache_dataset_on_device) {
device_dataset_storage =
raft::make_device_matrix<data_type, input_indexing_type>(res, row_count, feature_count);
auto device_data = device_dataset_storage->view();
raft::copy(res, device_data, host_data);
index_ = raft_index_type::template build<data_type, indexing_type, input_indexing_type>(
res, index_params, raft::make_const_mdspan(device_data));
} else {
index_ = raft_index_type::template build<data_type, indexing_type, input_indexing_type>(
res, index_params, raft::make_const_mdspan(host_data));
}
index_ = raft_index_type::template build<data_type, indexing_type, input_indexing_type>(
res, index_params, raft::make_const_mdspan(host_data));
}
}

Expand All @@ -447,7 +435,6 @@ struct raft_knowhere_index<IndexKind>::impl {
} else {
if (index_) {
auto const& res = get_device_resources_without_mempool();
raft::resource::set_workspace_to_pool_resource(res);
auto host_data = raft::make_host_matrix_view(data, row_count, feature_count);
device_dataset_storage =
raft::make_device_matrix<data_type, input_indexing_type>(res, row_count, feature_count);
Expand Down Expand Up @@ -488,11 +475,9 @@ struct raft_knowhere_index<IndexKind>::impl {
auto device_data_storage =
raft::make_device_matrix<data_type, input_indexing_type>(res, row_count, feature_count);
raft::copy(res, device_data_storage.view(), host_data);

auto device_bitset =
std::optional<raft::core::bitset<knowhere_bitset_data_type, knowhere_bitset_indexing_type>>{};
auto k_tmp = k;

if (bitset_data != nullptr && bitset_byte_size != 0) {
device_bitset =
raft::core::bitset<knowhere_bitset_data_type, knowhere_bitset_indexing_type>(res, bitset_size);
Expand Down
35 changes: 19 additions & 16 deletions src/common/raft/proto/raft_index.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,22 @@ post_filter(raft::resources const& res, filter_lambda_t const& sample_filter, in
// below, but I am not sure whether or not that would be a net benefit. This
// deserves some benchmarking unless pre-filtering gets in before we revisit
// this.
thrust::for_each(raft::resource::get_thrust_policy(res),
thrust::make_zip_iterator(
thrust::make_tuple(counter, mdspan_begin(index_mdspan), mdspan_begin(distance_mdspan))),
thrust::make_zip_iterator(thrust::make_tuple(
counter + index_mdspan.size(), mdspan_end(index_mdspan), mdspan_end(distance_mdspan))),
[=] __device__(auto& index_id_distance) {
auto index = thrust::get<0>(index_id_distance);
auto& id = thrust::get<1>(index_id_distance);
auto& distance = thrust::get<2>(index_id_distance);
if (!sample_filter(index / index_mdspan.extent(1), id)) {
id = std::numeric_limits<std::remove_reference_t<decltype(id)>>::max();
distance = std::numeric_limits<std::remove_reference_t<decltype(distance)>>::max();
}
});
thrust::for_each(
raft::resource::get_thrust_policy(res),
thrust::make_zip_iterator(
thrust::make_tuple(counter, mdspan_begin(index_mdspan), mdspan_begin(distance_mdspan))),
thrust::make_zip_iterator(
thrust::make_tuple(counter + index_mdspan.size(), mdspan_end(index_mdspan), mdspan_end(distance_mdspan))),
[=] __device__(const thrust::tuple<decltype(index_mdspan.extent(0)), typename index_mdspan_t::element_type&,
typename distance_mdspan_t::element_type&>& index_id_distance) {
auto index = thrust::get<0>(index_id_distance);
auto& id = thrust::get<1>(index_id_distance);
auto& distance = thrust::get<2>(index_id_distance);
if (!sample_filter(index / index_mdspan.extent(1), id)) {
id = std::numeric_limits<std::remove_reference_t<decltype(id)>>::max();
distance = std::numeric_limits<std::remove_reference_t<decltype(distance)>>::max();
}
});
for (auto i = 0; i < index_mdspan.extent(0); ++i) {
auto id_row_begin = mdspan_begin_row(index_mdspan, i);
auto id_row_end = mdspan_end_row(index_mdspan, i);
Expand Down Expand Up @@ -285,8 +287,9 @@ struct raft_index {
} else if constexpr (vector_index_kind == raft_index_kind::ivf_pq) {
return raft_index<underlying_index_type, raft_index_args...>{raft::neighbors::ivf_pq::build<T, IdxT>(
res, index_params, data.data_handle(), data.extent(0), data.extent(1))};
} else {
RAFT_FAIL("IVF flat does not support construction from host data");
} else if constexpr (vector_index_kind == raft_index_kind::ivf_flat) {
return raft_index<underlying_index_type, raft_index_args...>{
raft::neighbors::ivf_flat::build<T, IdxT>(res, index_params, data)};
}
} else {
if constexpr (vector_index_kind == raft_index_kind::brute_force) {
Expand Down
16 changes: 8 additions & 8 deletions tests/ut/test_gpu_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ TEST_CASE("Test All GPU Index", "[search]") {

auto cagra_gen = [base_gen]() {
knowhere::Json json = base_gen();
json[knowhere::indexparam::INTERMEDIATE_GRAPH_DEGREE] = 128;
json[knowhere::indexparam::GRAPH_DEGREE] = 64;
json[knowhere::indexparam::INTERMEDIATE_GRAPH_DEGREE] = 64;
json[knowhere::indexparam::GRAPH_DEGREE] = 32;
json[knowhere::indexparam::ITOPK_SIZE] = 128;
return json;
};
Expand Down Expand Up @@ -100,7 +100,7 @@ TEST_CASE("Test All GPU Index", "[search]") {
auto res = idx.Build(train_ds, json);
REQUIRE(res == knowhere::Status::success);
REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) ==
IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
auto results = idx.Search(query_ds, json, nullptr);
REQUIRE(results.has_value());
auto ids = results.value()->GetIds();
Expand Down Expand Up @@ -130,7 +130,7 @@ TEST_CASE("Test All GPU Index", "[search]") {
auto res = idx.Build(train_ds, json);
REQUIRE(res == knowhere::Status::success);
REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) ==
IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
std::vector<std::function<std::vector<uint8_t>(size_t, size_t)>> gen_bitset_funcs = {
GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet};
const auto bitset_percentages = {0.4f, 0.98f};
Expand All @@ -145,7 +145,7 @@ TEST_CASE("Test All GPU Index", "[search]") {
if (percentage == 0.98f) {
REQUIRE(recall > 0.4f);
} else {
REQUIRE(recall > 0.8f);
REQUIRE(recall > 0.7f);
}
}
}
Expand All @@ -171,7 +171,7 @@ TEST_CASE("Test All GPU Index", "[search]") {
auto res = idx.Build(train_ds, json);
REQUIRE(res == knowhere::Status::success);
REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) ==
IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
const auto topk_values = {// Tuple with [TopKValue, Threshold]
make_tuple(5, 0.85f), make_tuple(25, 0.85f), make_tuple(100, 0.85f)};

Expand Down Expand Up @@ -210,7 +210,7 @@ TEST_CASE("Test All GPU Index", "[search]") {
auto idx_ = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
idx_.Deserialize(bs);
REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) ==
IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
auto results = idx_.Search(query_ds, json, nullptr);
REQUIRE(results.has_value());
auto ids = results.value()->GetIds();
Expand Down Expand Up @@ -238,7 +238,7 @@ TEST_CASE("Test All GPU Index", "[search]") {
auto res = idx.Build(train_ds, json);
REQUIRE(res == knowhere::Status::success);
REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) ==
IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
std::vector<uint8_t> bitset_data(2);
bitset_data[0] = 0b10100010;
bitset_data[1] = 0b00100011;
Expand Down
2 changes: 1 addition & 1 deletion tests/ut/test_index_check.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ TEST_CASE("Test index has raw data", "[IndexHasRawData]") {
CHECK_FALSE(knowhere::IndexStaticFaced<fp32>::HasRawData(IndexEnum::INDEX_FAISS_HNSW_PRQ, ver, {}));

// diskann
#ifndef KNOWHERE_WITH_CARDINAL
#ifdef KNOWHERE_WITH_DISKANN
CHECK(knowhere::IndexStaticFaced<fp32>::HasRawData(IndexEnum::INDEX_DISKANN, ver,
knowhere::Json::parse(R"({"metric_type": "L2"})")));
CHECK_FALSE(knowhere::IndexStaticFaced<fp32>::HasRawData(IndexEnum::INDEX_DISKANN, ver,
Expand Down
Loading