Skip to content

Commit

Permalink
Fix compilation with the latest ctk.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Mar 14, 2024
1 parent e0f890b commit c760f85
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
6 changes: 3 additions & 3 deletions src/common/quantile.cu
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,13 @@ common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(

auto scan_key_it = dh::MakeTransformIterator<size_t>(
thrust::make_counting_iterator(0ul),
[=] __device__(size_t idx) { return dh::SegmentId(out_ptr, idx); });
[=] XGBOOST_DEVICE(size_t idx) { return dh::SegmentId(out_ptr, idx); });

auto scan_val_it = dh::MakeTransformIterator<Tuple>(
merge_path.data(), [=] __device__(Tuple const &t) -> Tuple {
merge_path.data(), [=] XGBOOST_DEVICE(Tuple const &t) -> Tuple {
auto ind = get_ind(t); // == 0 if element is from x
// x_counter, y_counter
return thrust::make_tuple<uint64_t, uint64_t>(!ind, ind);
return thrust::tuple<std::uint64_t, std::uint64_t>{!ind, ind};
});

// Compute the index for both x and y (which of the element in a and b are used in each
Expand Down
9 changes: 4 additions & 5 deletions src/data/ellpack_page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,10 @@ struct WriteCompressedEllpackFunctor {

using Tuple = thrust::tuple<size_t, size_t, size_t>;
__device__ size_t operator()(Tuple out) {
auto e = batch.GetElement(out.get<2>());
auto e = batch.GetElement(thrust::get<2>(out));
if (is_valid(e)) {
// -1 because the scan is inclusive
size_t output_position =
accessor.row_stride * e.row_idx + out.get<1>() - 1;
size_t output_position = accessor.row_stride * e.row_idx + thrust::get<1>(out) - 1;
uint32_t bin_idx = 0;
if (common::IsCat(feature_types, e.column_idx)) {
bin_idx = accessor.SearchBin<true>(e.value, e.column_idx);
Expand All @@ -192,8 +191,8 @@ template <typename Tuple>
struct TupleScanOp {
__device__ Tuple operator()(Tuple a, Tuple b) {
// Key equal
if (a.template get<0>() == b.template get<0>()) {
b.template get<1>() += a.template get<1>();
if (thrust::get<0>(a) == thrust::get<0>(b)) {
thrust::get<1>(b) += thrust::get<1>(a);
return b;
}
// Not equal
Expand Down

0 comments on commit c760f85

Please sign in to comment.