Skip to content

Commit

Permalink
optimizations for subsampling in InitData
Browse files Browse the repository at this point in the history
  • Loading branch information
SHVETS, KIRILL committed Apr 13, 2020
1 parent 6671b42 commit 73a3d6b
Showing 1 changed file with 41 additions and 1 deletion.
42 changes: 41 additions & 1 deletion src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -574,15 +574,55 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
<< "Only uniform sampling is supported, "
<< "gradient-based sampling is only support by GPU Hist.";
std::bernoulli_distribution coin_flip(param_.subsample);
auto& rnd = common::GlobalRandom();
#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG
const size_t nthread = this->nthread_;
std::vector<size_t> row_offsets(nthread, 0);
/* usage of mt19937_64 give 2x speed up for subsampling */
std::vector<std::mt19937> rnds(nthread);
/* create engine for each thread */
for (std::mt19937& r : rnds) {
r = rnd;
}
const size_t discard_size = info.num_row_ / nthread;
#pragma omp parallel num_threads(nthread)
{
const size_t tid = omp_get_thread_num();
const size_t ibegin = tid * discard_size;
const size_t iend = (tid == (nthread - 1)) ?
info.num_row_ : ibegin + discard_size;
std::bernoulli_distribution coin_flip(param_.subsample);

rnds[tid].discard(2*discard_size * tid);
for (size_t i = ibegin; i < iend; ++i) {
if (gpair[i].GetHess() >= 0.0f && coin_flip(rnds[tid])) {
p_row_indices[ibegin + row_offsets[tid]++] = i;
}
}
}
/* discard global engine */
rnd = rnds[nthread - 1];
size_t prefix_sum = row_offsets[0];
for (size_t i = 1; i < nthread; ++i) {
const size_t ibegin = i * discard_size;

for (size_t k = 0; k < row_offsets[i]; ++k) {
row_indices[prefix_sum + k] = row_indices[ibegin + k];
}
prefix_sum += row_offsets[i];
}
/* resize row_indices to reduce memory */
row_indices.resize(prefix_sum);
#else
std::bernoulli_distribution coin_flip(param_.subsample);
size_t j = 0;
for (size_t i = 0; i < info.num_row_; ++i) {
if (gpair[i].GetHess() >= 0.0f && coin_flip(rnd)) {
p_row_indices[j++] = i;
}
}
row_indices.resize(j);
#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG
} else {
MemStackAllocator<bool, 128> buff(this->nthread_);
bool* p_buff = buff.Get();
Expand Down

0 comments on commit 73a3d6b

Please sign in to comment.