Skip to content

Commit

Permalink
Optimized BuildHist function (#5156)
Browse files Browse the repository at this point in the history
  • Loading branch information
SmirnovEgorRu authored Jan 30, 2020
1 parent 4240dae commit c671632
Show file tree
Hide file tree
Showing 8 changed files with 611 additions and 185 deletions.
134 changes: 61 additions & 73 deletions src/common/hist_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -659,93 +659,87 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
}
}

/*!
* \brief fill a histogram by zeroes
*/
void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) {
memset(hist.data() + begin, '\0', (end-begin)*sizeof(tree::GradStats));
}

/*!
* \brief Increment hist as dst += add in range [begin, end)
*/
void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end) {
using FPType = decltype(tree::GradStats::sum_grad);
FPType* pdst = reinterpret_cast<FPType*>(dst.data());
const FPType* padd = reinterpret_cast<const FPType*>(add.data());

for (size_t i = 2 * begin; i < 2 * end; ++i) {
pdst[i] += padd[i];
}
}

/*!
* \brief Copy hist from src to dst in range [begin, end)
*/
void CopyHist(GHistRow dst, const GHistRow src, size_t begin, size_t end) {
using FPType = decltype(tree::GradStats::sum_grad);
FPType* pdst = reinterpret_cast<FPType*>(dst.data());
const FPType* psrc = reinterpret_cast<const FPType*>(src.data());

for (size_t i = 2 * begin; i < 2 * end; ++i) {
pdst[i] = psrc[i];
}
}

/*!
* \brief Compute Subtraction: dst = src1 - src2 in range [begin, end)
*/
void SubtractionHist(GHistRow dst, const GHistRow src1, const GHistRow src2,
size_t begin, size_t end) {
using FPType = decltype(tree::GradStats::sum_grad);
FPType* pdst = reinterpret_cast<FPType*>(dst.data());
const FPType* psrc1 = reinterpret_cast<const FPType*>(src1.data());
const FPType* psrc2 = reinterpret_cast<const FPType*>(src2.data());

for (size_t i = 2 * begin; i < 2 * end; ++i) {
pdst[i] = psrc1[i] - psrc2[i];
}
}


void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix& gmat,
GHistRow hist) {
const size_t nthread = static_cast<size_t>(this->nthread_);
data_.resize(nbins_ * nthread_);

const size_t* rid = row_indices.begin;
const size_t nrows = row_indices.Size();
const uint32_t* index = gmat.index.data();
const size_t* row_ptr = gmat.row_ptr.data();
const float* pgh = reinterpret_cast<const float*>(gpair.data());

double* hist_data = reinterpret_cast<double*>(hist.data());
double* data = reinterpret_cast<double*>(data_.data());

const size_t block_size = 512;
size_t n_blocks = nrows/block_size;
n_blocks += !!(nrows - n_blocks*block_size);

const size_t nthread_to_process = std::min(nthread, n_blocks);
memset(thread_init_.data(), '\0', nthread_to_process*sizeof(size_t));

const size_t cache_line_size = 64;
const size_t prefetch_offset = 10;
size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;

#pragma omp parallel for num_threads(nthread_to_process) schedule(guided)
for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
dmlc::omp_uint tid = omp_get_thread_num();
double* data_local_hist = ((nthread_to_process == 1) ? hist_data :
reinterpret_cast<double*>(data_.data() + tid * nbins_));
for (size_t i = 0; i < nrows; ++i) {
const size_t icol_start = row_ptr[rid[i]];
const size_t icol_end = row_ptr[rid[i]+1];

if (!thread_init_[tid]) {
memset(data_local_hist, '\0', 2*nbins_*sizeof(double));
thread_init_[tid] = true;
if (i < nrows - no_prefetch_size) {
PREFETCH_READ_T0(row_ptr + rid[i + prefetch_offset]);
PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
}

const size_t istart = iblock*block_size;
const size_t iend = (((iblock+1)*block_size > nrows) ? nrows : istart + block_size);
for (size_t i = istart; i < iend; ++i) {
const size_t icol_start = row_ptr[rid[i]];
const size_t icol_end = row_ptr[rid[i]+1];

if (i < nrows - no_prefetch_size) {
PREFETCH_READ_T0(row_ptr + rid[i + prefetch_offset]);
PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
}
for (size_t j = icol_start; j < icol_end; ++j) {
const uint32_t idx_bin = 2*index[j];
const size_t idx_gh = 2*rid[i];

for (size_t j = icol_start; j < icol_end; ++j) {
const uint32_t idx_bin = 2*index[j];
const size_t idx_gh = 2*rid[i];

data_local_hist[idx_bin] += pgh[idx_gh];
data_local_hist[idx_bin+1] += pgh[idx_gh+1];
}
}
}

if (nthread_to_process > 1) {
const size_t size = (2*nbins_);
const size_t block_size = 1024;
size_t n_blocks = size/block_size;
n_blocks += !!(size - n_blocks*block_size);

size_t n_worked_bins = 0;
for (size_t i = 0; i < nthread_to_process; ++i) {
if (thread_init_[i]) {
thread_init_[n_worked_bins++] = i;
}
}

#pragma omp parallel for num_threads(std::min(nthread, n_blocks)) schedule(guided)
for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
const size_t istart = iblock * block_size;
const size_t iend = (((iblock + 1) * block_size > size) ? size : istart + block_size);

const size_t bin = 2 * thread_init_[0] * nbins_;
memcpy(hist_data + istart, (data + bin + istart), sizeof(double) * (iend - istart));

for (size_t i_bin_part = 1; i_bin_part < n_worked_bins; ++i_bin_part) {
const size_t bin = 2 * thread_init_[i_bin_part] * nbins_;
for (size_t i = istart; i < iend; i++) {
hist_data[i] += data[bin + i];
}
}
hist_data[idx_bin] += pgh[idx_gh];
hist_data[idx_bin+1] += pgh[idx_gh+1];
}
}
}
Expand Down Expand Up @@ -801,10 +795,6 @@ void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
}

void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
tree::GradStats* p_self = self.data();
tree::GradStats* p_sibling = sibling.data();
tree::GradStats* p_parent = parent.data();

const size_t size = self.size();
CHECK_EQ(sibling.size(), size);
CHECK_EQ(parent.size(), size);
Expand All @@ -816,9 +806,7 @@ void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow pa
for (omp_ulong iblock = 0; iblock < n_blocks; ++iblock) {
const size_t ibegin = iblock*block_size;
const size_t iend = (((iblock+1)*block_size > size) ? size : ibegin + block_size);
for (bst_omp_uint bin_id = ibegin; bin_id < iend; bin_id++) {
p_self[bin_id].SetSubstract(p_parent[bin_id], p_sibling[bin_id]);
}
SubtractionHist(self, parent, sibling, ibegin, iend);
}
}

Expand Down
Loading

0 comments on commit c671632

Please sign in to comment.