From 8f8bd8147adbe009a13c2929d1e032842c627192 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 9 Jun 2022 01:33:41 +0800 Subject: [PATCH] Fix LTR with weighted Quantile DMatrix. (#7975) * Fix LTR with weighted Quantile DMatrix. * Better tests. --- src/common/hist_util.cuh | 12 +++++----- src/common/quantile.cc | 35 +++++------------------------- src/common/quantile.h | 23 ++++++++++++++++++++ tests/cpp/common/test_hist_util.cc | 12 ++++++++++ tests/cpp/common/test_hist_util.cu | 21 +++++++++++++++++- tests/cpp/common/test_hist_util.h | 22 +++++++++++++------ 6 files changed, 83 insertions(+), 42 deletions(-) diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index 419febc22bc9..8fac9fca2f14 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -184,8 +184,6 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, dh::safe_cuda(cudaSetDevice(device)); info.weights_.SetDevice(device); auto weights = info.weights_.ConstDeviceSpan(); - dh::caching_device_vector group_ptr(info.group_ptr_); - auto d_group_ptr = dh::ToSpan(group_ptr); auto batch_iter = dh::MakeTransformIterator( thrust::make_counting_iterator(0llu), @@ -205,9 +203,13 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, auto d_temp_weights = dh::ToSpan(temp_weights); if (is_ranking) { + if (!weights.empty()) { + CHECK_EQ(weights.size(), info.group_ptr_.size() - 1); + } + dh::caching_device_vector group_ptr(info.group_ptr_); + auto d_group_ptr = dh::ToSpan(group_ptr); auto const weight_iter = dh::MakeTransformIterator( - thrust::make_constant_iterator(0lu), - [=]__device__(size_t idx) -> float { + thrust::make_counting_iterator(0lu), [=] __device__(size_t idx) -> float { auto ridx = batch.GetElement(idx).row_idx; bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx); return weights[group_idx]; @@ -272,7 +274,7 @@ void AdapterDeviceSketch(Batch batch, int num_bins, size_t num_cols = batch.NumCols(); size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows); int32_t device = sketch_container->DeviceIdx(); - bool weighted = info.weights_.Size() != 0; + bool weighted = !info.weights_.Empty(); if (weighted) { sketch_batch_num_elements = detail::SketchBatchNumElements( diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 78dcfb3c2aa5..42fe719708c7 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -122,27 +122,6 @@ std::vector MergeWeights(MetaInfo const &info, Span hessian, } return results; } - -std::vector UnrollGroupWeights(MetaInfo const &info) { - std::vector const &group_weights = info.weights_.HostVector(); - if (group_weights.empty()) { - return group_weights; - } - - size_t n_samples = info.num_row_; - auto const &group_ptr = info.group_ptr_; - std::vector results(n_samples); - CHECK_GE(group_ptr.size(), 2); - CHECK_EQ(group_ptr.back(), n_samples); - size_t cur_group = 0; - for (size_t i = 0; i < n_samples; ++i) { - results[i] = group_weights[cur_group]; - if (i == group_ptr[cur_group + 1]) { - cur_group++; - } - } - return results; -} } // anonymous namespace template @@ -156,12 +135,10 @@ void SketchContainerImpl::PushRowPage(SparsePage const &page, MetaInfo // glue these conditions using ternary operator to avoid making data copies. auto const &weights = - hessian.empty() - ? (use_group_ind_ ? UnrollGroupWeights(info) // use group weight - : info.weights_.HostVector()) // use sample weight - : MergeWeights( - info, hessian, use_group_ind_, - n_threads_); // use hessian merged with group/sample weights + hessian.empty() ? (use_group_ind_ ? detail::UnrollGroupWeights(info) // use group weight + : info.weights_.HostVector()) // use sample weight + : MergeWeights(info, hessian, use_group_ind_, + n_threads_); // use hessian merged with group/sample weights if (!weights.empty()) { CHECK_EQ(weights.size(), info.num_row_); } @@ -563,8 +540,8 @@ void SortedSketchContainer::PushColPage(SparsePage const &page, MetaInfo const & monitor_.Start(__func__); // glue these conditions using ternary operator to avoid making data copies. auto const &weights = - hessian.empty() ? (use_group_ind_ ? UnrollGroupWeights(info) // use group weight - : info.weights_.HostVector()) // use sample weight + hessian.empty() ? (use_group_ind_ ? detail::UnrollGroupWeights(info) // use group weight + : info.weights_.HostVector()) // use sample weight : MergeWeights(info, hessian, use_group_ind_, n_threads_); // use hessian merged with group/sample weights CHECK_EQ(weights.size(), info.num_row_); diff --git a/src/common/quantile.h b/src/common/quantile.h index d7c65c8c06eb..7f08be442fd2 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -697,6 +697,29 @@ class WXQuantileSketch : public QuantileSketchTemplate > { }; +namespace detail { +inline std::vector UnrollGroupWeights(MetaInfo const &info) { + std::vector const &group_weights = info.weights_.HostVector(); + if (group_weights.empty()) { + return group_weights; + } + + size_t n_samples = info.num_row_; + auto const &group_ptr = info.group_ptr_; + std::vector results(n_samples); + CHECK_GE(group_ptr.size(), 2); + CHECK_EQ(group_ptr.back(), n_samples); + size_t cur_group = 0; + for (size_t i = 0; i < n_samples; ++i) { + results[i] = group_weights[cur_group]; + if (i == group_ptr[cur_group + 1]) { + cur_group++; + } + } + return results; +} +} // namespace detail + class HistogramCuts; /*! diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 3dd33e03a316..418caab134c1 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -366,6 +366,7 @@ void TestSketchFromWeights(bool with_group) { ValidateCuts(cuts, m.get(), kBins); if (with_group) { + m->Info().weights_ = decltype(m->Info().weights_)(); // remove weight HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0)); for (size_t i = 0; i < cuts.Values().size(); ++i) { EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]); @@ -377,6 +378,17 @@ void TestSketchFromWeights(bool with_group) { ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i)); } } + + if (with_group) { + auto& h_weights = info.weights_.HostVector(); + h_weights.resize(kGroups); + // Generate different weight. + for (size_t i = 0; i < h_weights.size(); ++i) { + h_weights[i] = static_cast(i + 1) / static_cast(kGroups); + } + HistogramCuts weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0)); + ValidateCuts(weighted, m.get(), kBins); + } } TEST(HistUtil, SketchFromWeights) { diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index f02bff547c5a..612f84840b7f 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -593,9 +593,10 @@ void TestAdapterSketchFromWeights(bool with_group) { ValidateCuts(cuts, dmat.get(), kBins); if (with_group) { + dmat->Info().weights_ = decltype(dmat->Info().weights_)(); // remove weight HistogramCuts non_weighted = DeviceSketch(0, dmat.get(), kBins, 0); for (size_t i = 0; i < cuts.Values().size(); ++i) { - EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]); + ASSERT_EQ(cuts.Values()[i], non_weighted.Values()[i]); } for (size_t i = 0; i < cuts.MinValues().size(); ++i) { ASSERT_EQ(cuts.MinValues()[i], non_weighted.MinValues()[i]); @@ -604,6 +605,24 @@ void TestAdapterSketchFromWeights(bool with_group) { ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i)); } } + + if (with_group) { + common::HistogramCuts weighted; + auto& h_weights = info.weights_.HostVector(); + h_weights.resize(kGroups); + // Generate different weight. + for (size_t i = 0; i < h_weights.size(); ++i) { + // FIXME(jiamingy): Some entries generated GPU test cannot pass the validate cuts if + // we use more diverse weights, partially caused by + // https://github.com/dmlc/xgboost/issues/7946 + h_weights[i] = (i % 2 == 0 ? 1 : 2) / static_cast(kGroups); + } + SketchContainer sketch_container(ft, kBins, kCols, kRows, 0); + AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits::quiet_NaN(), + &sketch_container); + sketch_container.MakeCuts(&weighted); + ValidateCuts(weighted, dmat.get(), kBins); + } } TEST(HistUtil, AdapterSketchFromWeights) { diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index 6a370c59700b..5b16bd0b14c4 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -98,7 +98,11 @@ inline void TestBinDistribution(const HistogramCuts &cuts, int column_idx, int num_bins) { std::map bin_weights; for (auto i = 0ull; i < sorted_column.size(); i++) { - bin_weights[cuts.SearchBin(sorted_column[i], column_idx)] += sorted_weights[i]; + auto bin_idx = cuts.SearchBin(sorted_column[i], column_idx); + if (bin_weights.find(bin_idx) == bin_weights.cend()) { + bin_weights[bin_idx] = 0; + } + bin_weights.at(bin_idx) += sorted_weights[i]; } int local_num_bins = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx]; auto total_weight = std::accumulate(sorted_weights.begin(), sorted_weights.end(),0); @@ -176,8 +180,7 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx, } } -inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, - int num_bins) { +inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, int num_bins) { // Collect data into columns std::vector> columns(dmat->Info().num_col_); for (auto& batch : dmat->GetBatches()) { @@ -189,17 +192,22 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, } } } + + // construct weights. + std::vector w = dmat->Info().group_ptr_.empty() ? dmat->Info().weights_.HostVector() + : detail::UnrollGroupWeights(dmat->Info()); + // Sort for (auto i = 0ull; i < columns.size(); i++) { auto& col = columns.at(i); - const auto& w = dmat->Info().weights_.HostVector(); - std::vector index(col.size()); + std::vector index(col.size()); std::iota(index.begin(), index.end(), 0); - std::sort(index.begin(), index.end(), - [=](size_t a, size_t b) { return col[a] < col[b]; }); + std::sort(index.begin(), index.end(), [=](size_t a, size_t b) { return col[a] < col[b]; }); std::vector sorted_column(col.size()); std::vector sorted_weights(col.size(), 1.0); + const auto& w = dmat->Info().weights_.HostVector(); + for (auto j = 0ull; j < col.size(); j++) { sorted_column[j] = col[index[j]]; if (w.size() == col.size()) {