Skip to content

Commit

Permalink
optimizations for subsampling in InitData
Browse files Browse the repository at this point in the history
  • Loading branch information
SHVETS, KIRILL committed Apr 15, 2020
1 parent c481f96 commit 93c5f3d
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions tests/cpp/tree/test_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,31 @@ class QuantileHistMock : public QuantileHistMaker {
}
}

void TestInitDataSampling(const GHistIndexMatrix& gmat,
const std::vector<GradientPair>& gpair,
DMatrix* p_fmat,
const RegTree& tree) {
const size_t nthreads = omp_get_num_threads();
// save state of global rng engine
auto initial_rnd = common::GlobalRandom();
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
std::vector<size_t> row_indices_initial = *row_set_collection_.Data();

for (size_t i_nthreads = 1; i_nthreads < 4; ++i_nthreads) {
omp_set_num_threads(i_nthreads);
// return initial state of global rng engine
common::GlobalRandom() = initial_rnd;
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
std::vector<size_t>& row_indices = *row_set_collection_.Data();
ASSERT_EQ(row_indices_initial.size(), row_indices.size());
for (size_t i = 0; i < row_indices_initial.size(); ++i) {
ASSERT_EQ(row_indices_initial[i], row_indices[i]);
}
}
omp_set_num_threads(nthreads);
}


void TestBuildHist(int nid,
const GHistIndexMatrix& gmat,
const DMatrix& fmat,
Expand Down Expand Up @@ -266,6 +291,20 @@ class QuantileHistMock : public QuantileHistMaker {
builder_->TestInitData(gmat, gpair, dmat_.get(), tree);
}

void TestInitDataSampling() {
size_t constexpr kMaxBins = 4;
common::GHistIndexMatrix gmat;
gmat.Init(dmat_.get(), kMaxBins);

RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);

std::vector<GradientPair> gpair =
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };

builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
}
void TestBuildHist() {
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);
Expand All @@ -292,6 +331,15 @@ TEST(QuantileHist, InitData) {
maker.TestInitData();
}

TEST(QuantileHist, InitDataSampling) {
const float subsample = 0.5;
std::vector<std::pair<std::string, std::string>> cfg
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())},
{"subsample", std::to_string(subsample)}};
QuantileHistMock maker(cfg);
maker.TestInitDataSampling();
}

TEST(QuantileHist, BuildHist) {
// Don't enable feature grouping
std::vector<std::pair<std::string, std::string>> cfg
Expand Down

0 comments on commit 93c5f3d

Please sign in to comment.