Skip to content

Commit

Permalink
Fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Apr 16, 2021
1 parent 482939a commit 4e15c6e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/common/device_helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1392,8 +1392,8 @@ void SegmentedArgSort(xgboost::common::Span<U> values,
TemporaryArray<xgboost::common::byte> temp_storage(bytes);
detail::DeviceSegmentedRadixSortPair<!accending>(
temp_storage.data().get(), bytes, values.data(), values_out.data().get(),
sorted_idx_out.data().get(), sorted_idx.data(), sorted_idx.size(), n_groups,
group_ptr.data(), group_ptr.data() + 1);
sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(),
n_groups, group_ptr.data(), group_ptr.data() + 1);

safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
Expand Down
2 changes: 1 addition & 1 deletion src/metric/auc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
});

// unique values are sparse, so we need a CSR style indptr
dh::TemporaryArray<uint32_t> unique_class_ptr(class_ptr.size() + 1);
dh::TemporaryArray<uint32_t> unique_class_ptr(class_ptr.size());
auto d_unique_class_ptr = dh::ToSpan(unique_class_ptr);
auto n_uniques = dh::SegmentedUniqueByKey(
thrust::cuda::par(alloc),
Expand Down

0 comments on commit 4e15c6e

Please sign in to comment.