diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index a66711a784f6..a6db19a6659a 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -1321,15 +1321,16 @@ void ArgSort(xgboost::common::Span keys, xgboost::common::Span sorted_i TemporaryArray out(keys.size()); cub::DoubleBuffer d_keys(const_cast(keys.data()), out.data().get()); + TemporaryArray sorted_idx_out(sorted_idx.size()); cub::DoubleBuffer d_values(const_cast(sorted_idx.data()), - sorted_idx.data()); + sorted_idx_out.data().get()); if (accending) { void *d_temp_storage = nullptr; safe_cuda((cub::DispatchRadixSort::Dispatch( d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false, nullptr, false))); - dh::TemporaryArray storage(bytes); + TemporaryArray storage(bytes); d_temp_storage = storage.data().get(); safe_cuda((cub::DispatchRadixSort::Dispatch( d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, @@ -1339,12 +1340,15 @@ void ArgSort(xgboost::common::Span keys, xgboost::common::Span sorted_i safe_cuda((cub::DispatchRadixSort::Dispatch( d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false, nullptr, false))); - dh::TemporaryArray storage(bytes); + TemporaryArray storage(bytes); d_temp_storage = storage.data().get(); safe_cuda((cub::DispatchRadixSort::Dispatch( d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false, nullptr, false))); } + + safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(), + sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice)); } namespace detail { @@ -1379,14 +1383,19 @@ void SegmentedArgSort(xgboost::common::Span values, size_t bytes = 0; Iota(sorted_idx); TemporaryArray> values_out(values.size()); + TemporaryArray> sorted_idx_out(sorted_idx.size()); + detail::DeviceSegmentedRadixSortPair( nullptr, bytes, values.data(), values_out.data().get(), sorted_idx.data(), - sorted_idx.data(), sorted_idx.size(), n_groups, group_ptr.data(), + sorted_idx_out.data().get(), sorted_idx.size(), n_groups, group_ptr.data(), group_ptr.data() + 1); - dh::TemporaryArray temp_storage(bytes); + TemporaryArray temp_storage(bytes); detail::DeviceSegmentedRadixSortPair( temp_storage.data().get(), bytes, values.data(), values_out.data().get(), - sorted_idx.data(), sorted_idx.data(), sorted_idx.size(), n_groups, - group_ptr.data(), group_ptr.data() + 1); + sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), + n_groups, group_ptr.data(), group_ptr.data() + 1); + + safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(), + sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice)); } } // namespace dh diff --git a/src/metric/auc.cu b/src/metric/auc.cu index 433d84710839..ea837e413590 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -269,7 +269,7 @@ float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info }); // unique values are sparse, so we need a CSR style indptr - dh::TemporaryArray unique_class_ptr(class_ptr.size() + 1); + dh::TemporaryArray unique_class_ptr(class_ptr.size()); auto d_unique_class_ptr = dh::ToSpan(unique_class_ptr); auto n_uniques = dh::SegmentedUniqueByKey( thrust::cuda::par(alloc),