diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index b93615c3fe2a..ad07930b4b1b 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -96,6 +96,31 @@ class QuantileHistMock : public QuantileHistMaker { } } + void TestInitDataSampling(const GHistIndexMatrix& gmat, + const std::vector& gpair, + DMatrix* p_fmat, + const RegTree& tree) { + const size_t nthreads = omp_get_num_threads(); + // save state of global rng engine + auto initial_rnd = common::GlobalRandom(); + RealImpl::InitData(gmat, gpair, *p_fmat, tree); + std::vector row_indices_initial = *row_set_collection_.Data(); + + for (size_t i_nthreads = 1; i_nthreads < 4; ++i_nthreads) { + omp_set_num_threads(i_nthreads); + // return initial state of global rng engine + common::GlobalRandom() = initial_rnd; + RealImpl::InitData(gmat, gpair, *p_fmat, tree); + std::vector& row_indices = *row_set_collection_.Data(); + ASSERT_EQ(row_indices_initial.size(), row_indices.size()); + for (size_t i = 0; i < row_indices_initial.size(); ++i) { + ASSERT_EQ(row_indices_initial[i], row_indices[i]); + } + } + omp_set_num_threads(nthreads); + } + + void TestBuildHist(int nid, const GHistIndexMatrix& gmat, const DMatrix& fmat, @@ -266,6 +291,20 @@ class QuantileHistMock : public QuantileHistMaker { builder_->TestInitData(gmat, gpair, dmat_.get(), tree); } + void TestInitDataSampling() { + size_t constexpr kMaxBins = 4; + common::GHistIndexMatrix gmat; + gmat.Init(dmat_.get(), kMaxBins); + + RegTree tree = RegTree(); + tree.param.UpdateAllowUnknown(cfg_); + + std::vector gpair = + { {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, + {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} }; + + builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree); + } void TestBuildHist() { RegTree tree = RegTree(); tree.param.UpdateAllowUnknown(cfg_); @@ -292,6 +331,15 @@ TEST(QuantileHist, InitData) { maker.TestInitData(); } +TEST(QuantileHist, InitDataSampling) { + const float subsample = 0.5; + std::vector> cfg + {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}, + {"subsample", std::to_string(subsample)}}; + QuantileHistMock maker(cfg); + maker.TestInitDataSampling(); +} + TEST(QuantileHist, BuildHist) { // Don't enable feature grouping std::vector> cfg