diff --git a/cpp/src/metrics/pairwise_distance.cu b/cpp/src/metrics/pairwise_distance.cu index 1928091c33..7a05e9b63d 100644 --- a/cpp/src/metrics/pairwise_distance.cu +++ b/cpp/src/metrics/pairwise_distance.cu @@ -1,6 +1,6 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,7 +31,6 @@ #include #include #include -#include #include namespace ML { @@ -164,23 +163,18 @@ void pairwiseDistance_sparse(const raft::handle_t& handle, raft::distance::DistanceType metric, float metric_arg) { - raft::sparse::distance::distances_config_t dist_config(handle); + auto out = raft::make_device_matrix_view(dist, y_nrows, x_nrows); - dist_config.b_nrows = x_nrows; - dist_config.b_ncols = n_cols; - dist_config.b_nnz = x_nnz; - dist_config.b_indptr = x_indptr; - dist_config.b_indices = x_indices; - dist_config.b_data = x; + auto x_structure = raft::make_device_compressed_structure_view( + x_indptr, x_indices, x_nrows, n_cols, x_nnz); + auto x_csr_view = raft::make_device_csr_matrix_view(x, x_structure); - dist_config.a_nrows = y_nrows; - dist_config.a_ncols = n_cols; - dist_config.a_nnz = y_nnz; - dist_config.a_indptr = y_indptr; - dist_config.a_indices = y_indices; - dist_config.a_data = y; + auto y_structure = raft::make_device_compressed_structure_view( + y_indptr, y_indices, y_nrows, n_cols, y_nnz); + auto y_csr_view = raft::make_device_csr_matrix_view(y, y_structure); - raft::sparse::distance::pairwiseDistance(dist, dist_config, metric, metric_arg); + raft::sparse::distance::pairwise_distance( + handle, y_csr_view, x_csr_view, out, metric, metric_arg); } void pairwiseDistance_sparse(const raft::handle_t& handle,