diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index 75e3fa586a37..c67b2b0e7a45 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -286,15 +286,28 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) { } } - Init(&sketchs, max_num_bins); + Init(&sketchs, max_num_bins, info.num_row_); monitor_.Stop(__FUNCTION__); } +/** + * \param [in,out] in_sketchs + * \param max_num_bins The maximum number of bins. + * \param max_rows Number of rows in this DMatrix. + */ void DenseCuts::Init -(std::vector* in_sketchs, uint32_t max_num_bins) { +(std::vector* in_sketchs, uint32_t max_num_bins, size_t max_rows) { monitor_.Start(__func__); std::vector& sketchs = *in_sketchs; + + // Compute how many cut samples we need at each node + // Do not require more than the total number of rows in training data + // This allows efficient training on wide data + size_t global_max_rows = max_rows; + rabit::Allreduce(&global_max_rows, 1); constexpr int kFactor = 8; + size_t intermediate_num_cuts = + std::min(global_max_rows, static_cast(max_num_bins * kFactor)); // gather the histogram data rabit::SerializeReducer sreducer; std::vector summary_array; @@ -302,11 +315,11 @@ void DenseCuts::Init for (size_t i = 0; i < sketchs.size(); ++i) { WQSketch::SummaryContainer out; sketchs[i].GetSummary(&out); - summary_array[i].Reserve(max_num_bins * kFactor); - summary_array[i].SetPrune(out, max_num_bins * kFactor); + summary_array[i].Reserve(intermediate_num_cuts); + summary_array[i].SetPrune(out, intermediate_num_cuts); } CHECK_EQ(summary_array.size(), in_sketchs->size()); - size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor); + size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts); // TODO(chenqin): rabit failure recovery assumes no boostrap onetime call after loadcheckpoint // we need to move this allreduce before loadcheckpoint call in future sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, 
summary_array.size()); diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 2420ef36adea..89a59ec8b19b 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -148,7 +148,7 @@ class GPUSketcher { this->SketchBatch(batch, info); } - hmat->Init(&sketch_container_->sketches_, max_bin_); + hmat->Init(&sketch_container_->sketches_, max_bin_, info.num_row_); return row_stride_; } diff --git a/src/common/hist_util.h b/src/common/hist_util.h index a47eae4ae807..41c3ee82514c 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -244,7 +244,7 @@ class DenseCuts : public CutsBuilder { CutsBuilder(container) { monitor_.Init(__FUNCTION__); } - void Init(std::vector* sketchs, uint32_t max_num_bins); + void Init(std::vector* sketchs, uint32_t max_num_bins, size_t max_rows); void Build(DMatrix* p_fmat, uint32_t max_num_bins) override; }; diff --git a/src/common/quantile.h b/src/common/quantile.h index 9ad8aa2537d2..27e5560a9933 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -702,6 +702,7 @@ class QuantileSketchTemplate { nlevel = 1; while (true) { limit_size = static_cast(ceil(nlevel / eps)) + 1; + limit_size = std::min(maxn, limit_size); size_t n = (1ULL << nlevel); if (n * limit_size >= maxn) break; ++nlevel; @@ -709,7 +710,8 @@ class QuantileSketchTemplate { // check invariant size_t n = (1ULL << nlevel); CHECK(n * limit_size >= maxn) << "invalid init parameter"; - CHECK(nlevel <= limit_size * eps) << "invalid init parameter"; + CHECK(nlevel <= std::max(1, static_cast(limit_size * eps))) + << "invalid init parameter"; } /*!