diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc
index d9003841cdd2..05dd9cc060a9 100644
--- a/src/tree/updater_quantile_hist.cc
+++ b/src/tree/updater_quantile_hist.cc
@@ -536,6 +536,63 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache(
   return true;
 }
 
+void QuantileHistMaker::Builder::InitSampling(const std::vector<GradientPair>& gpair,
+                                              const DMatrix& fmat,
+                                              std::vector<size_t>* row_indices) {
+  const auto& info = fmat.Info();
+  auto& rnd = common::GlobalRandom();
+  std::vector<size_t>& row_indices_local = *row_indices;
+  size_t* p_row_indices = row_indices_local.data();
+#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG
+  std::bernoulli_distribution coin_flip(param_.subsample);
+  size_t j = 0;
+  for (size_t i = 0; i < info.num_row_; ++i) {
+    if (gpair[i].GetHess() >= 0.0f && coin_flip(rnd)) {
+      p_row_indices[j++] = i;
+    }
+  }
+  /* resize row_indices to reduce memory */
+  row_indices_local.resize(j);
+#else
+  const size_t nthread = this->nthread_;
+  std::vector<size_t> row_offsets(nthread, 0);
+  /* usage of mt19937_64 give 2x speed up for subsampling */
+  std::vector<std::mt19937> rnds(nthread);
+  /* create engine for each thread */
+  for (std::mt19937& r : rnds) {
+    r = rnd;
+  }
+  const size_t discard_size = info.num_row_ / nthread;
+  #pragma omp parallel num_threads(nthread)
+  {
+    const size_t tid = omp_get_thread_num();
+    const size_t ibegin = tid * discard_size;
+    const size_t iend = (tid == (nthread - 1)) ?
+                        info.num_row_ : ibegin + discard_size;
+    std::bernoulli_distribution coin_flip(param_.subsample);
+
+    rnds[tid].discard(2*discard_size * tid);
+    for (size_t i = ibegin; i < iend; ++i) {
+      if (gpair[i].GetHess() >= 0.0f && coin_flip(rnds[tid])) {
+        p_row_indices[ibegin + row_offsets[tid]++] = i;
+      }
+    }
+  }
+  /* discard global engine */
+  rnd = rnds[nthread - 1];
+  size_t prefix_sum = row_offsets[0];
+  for (size_t i = 1; i < nthread; ++i) {
+    const size_t ibegin = i * discard_size;
+
+    for (size_t k = 0; k < row_offsets[i]; ++k) {
+      row_indices_local[prefix_sum + k] = row_indices_local[ibegin + k];
+    }
+    prefix_sum += row_offsets[i];
+  }
+  /* resize row_indices to reduce memory */
+  row_indices_local.resize(prefix_sum);
+#endif  // XGBOOST_CUSTOMIZE_GLOBAL_PRNG
+}
 void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
                                           const std::vector<GradientPair>& gpair,
                                           const DMatrix& fmat,
@@ -569,22 +626,14 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
 
     std::vector<size_t>& row_indices = *row_set_collection_.Data();
     row_indices.resize(info.num_row_);
-    auto* p_row_indices = row_indices.data();
+    size_t* p_row_indices = row_indices.data();
     // mark subsample and build list of member rows
 
     if (param_.subsample < 1.0f) {
       CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
         << "Only uniform sampling is supported, "
         << "gradient-based sampling is only support by GPU Hist.";
-      std::bernoulli_distribution coin_flip(param_.subsample);
-      auto& rnd = common::GlobalRandom();
-      size_t j = 0;
-      for (size_t i = 0; i < info.num_row_; ++i) {
-        if (gpair[i].GetHess() >= 0.0f && coin_flip(rnd)) {
-          p_row_indices[j++] = i;
-        }
-      }
-      row_indices.resize(j);
+      InitSampling(gpair, fmat, &row_indices);
     } else {
       MemStackAllocator<bool, 128> buff(this->nthread_);
       bool* p_buff = buff.Get();
diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h
index 5d0e09772e5c..0c7acbb6d539 100644
--- a/src/tree/updater_quantile_hist.h
+++ b/src/tree/updater_quantile_hist.h
@@ -202,6 +202,9 @@ class QuantileHistMaker: public TreeUpdater {
                   const DMatrix& fmat,
                   const RegTree& tree);
 
+    void InitSampling(const std::vector<GradientPair>& gpair,
+                      const DMatrix& fmat, std::vector<size_t>* row_indices);
+
     void EvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
                         const GHistIndexMatrix& gmat,
                         const HistCollection& hist,
diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc
index b93615c3fe2a..ad07930b4b1b 100644
--- a/tests/cpp/tree/test_quantile_hist.cc
+++ b/tests/cpp/tree/test_quantile_hist.cc
@@ -96,6 +96,31 @@ class QuantileHistMock : public QuantileHistMaker {
       }
     }
 
+    void TestInitDataSampling(const GHistIndexMatrix& gmat,
+                              const std::vector<GradientPair>& gpair,
+                              DMatrix* p_fmat,
+                              const RegTree& tree) {
+      const size_t nthreads = omp_get_num_threads();
+      // save state of global rng engine
+      auto initial_rnd = common::GlobalRandom();
+      RealImpl::InitData(gmat, gpair, *p_fmat, tree);
+      std::vector<size_t> row_indices_initial = *row_set_collection_.Data();
+
+      for (size_t i_nthreads = 1; i_nthreads < 4; ++i_nthreads) {
+        omp_set_num_threads(i_nthreads);
+        // return initial state of global rng engine
+        common::GlobalRandom() = initial_rnd;
+        RealImpl::InitData(gmat, gpair, *p_fmat, tree);
+        std::vector<size_t>& row_indices = *row_set_collection_.Data();
+        ASSERT_EQ(row_indices_initial.size(), row_indices.size());
+        for (size_t i = 0; i < row_indices_initial.size(); ++i) {
+          ASSERT_EQ(row_indices_initial[i], row_indices[i]);
+        }
+      }
+      omp_set_num_threads(nthreads);
+    }
+
+
     void TestBuildHist(int nid,
                        const GHistIndexMatrix& gmat,
                        const DMatrix& fmat,
@@ -266,6 +291,20 @@ class QuantileHistMock : public QuantileHistMaker {
     builder_->TestInitData(gmat, gpair, dmat_.get(), tree);
   }
 
+  void TestInitDataSampling() {
+    size_t constexpr kMaxBins = 4;
+    common::GHistIndexMatrix gmat;
+    gmat.Init(dmat_.get(), kMaxBins);
+
+    RegTree tree = RegTree();
+    tree.param.UpdateAllowUnknown(cfg_);
+
+    std::vector<GradientPair> gpair =
+        { {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
+          {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
+
+    builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
+  }
   void TestBuildHist() {
     RegTree tree = RegTree();
     tree.param.UpdateAllowUnknown(cfg_);
@@ -292,6 +331,15 @@ TEST(QuantileHist, InitData) {
   maker.TestInitData();
 }
 
+TEST(QuantileHist, InitDataSampling) {
+  const float subsample = 0.5;
+  std::vector<std::pair<std::string, std::string>> cfg
+      {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())},
+       {"subsample", std::to_string(subsample)}};
+  QuantileHistMock maker(cfg);
+  maker.TestInitDataSampling();
+}
+
 TEST(QuantileHist, BuildHist) {
   // Don't enable feature grouping
   std::vector<std::pair<std::string, std::string>> cfg