Skip to content

Commit

Permalink
Pass n_threads.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Oct 27, 2021
1 parent d579b50 commit 7140ba6
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 18 deletions.
13 changes: 8 additions & 5 deletions src/common/hist_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -259,18 +259,19 @@ GHistBuilder<double>::BuildHist<false>(const std::vector<GradientPair> &gpair,
const GHistIndexMatrix &gmat,
GHistRow<double> hist) const;

template<typename GradientSumT>
template <typename GradientSumT>
void GHistBuilder<GradientSumT>::SubtractionTrick(GHistRowT self,
GHistRowT sibling,
GHistRowT parent) {
GHistRowT parent,
int32_t n_threads) {
const size_t size = self.size();
CHECK_EQ(sibling.size(), size);
CHECK_EQ(parent.size(), size);

const size_t block_size = 1024; // aproximatly 1024 values per block
size_t n_blocks = size/block_size + !!(size%block_size);

ParallelFor(omp_ulong(n_blocks), [&](omp_ulong iblock) {
ParallelFor(omp_ulong(n_blocks), n_threads, [&](omp_ulong iblock) {
const size_t ibegin = iblock * block_size;
const size_t iend =
(((iblock + 1) * block_size > size) ? size : ibegin + block_size);
Expand All @@ -280,11 +281,13 @@ void GHistBuilder<GradientSumT>::SubtractionTrick(GHistRowT self,
template
void GHistBuilder<float>::SubtractionTrick(GHistRow<float> self,
GHistRow<float> sibling,
GHistRow<float> parent);
GHistRow<float> parent,
int32_t n_threads);
template
void GHistBuilder<double>::SubtractionTrick(GHistRow<double> self,
GHistRow<double> sibling,
GHistRow<double> parent);
GHistRow<double> parent,
int32_t n_threads);

} // namespace common
} // namespace xgboost
14 changes: 5 additions & 9 deletions src/common/hist_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -567,26 +567,22 @@ class GHistBuilder {
using GHistRowT = GHistRow<GradientSumT>;

GHistBuilder() = default;
GHistBuilder(size_t nthread, uint32_t nbins) : nthread_{nthread}, nbins_{nbins} {}
explicit GHistBuilder(uint32_t nbins): nbins_{nbins} {}

// construct a histogram via histogram aggregation
template <bool any_missing>
void BuildHist(const std::vector<GradientPair>& gpair,
void BuildHist(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix& gmat,
GHistRowT hist) const;
const GHistIndexMatrix &gmat, GHistRowT hist) const;
// construct a histogram via subtraction trick
void SubtractionTrick(GHistRowT self,
GHistRowT sibling,
GHistRowT parent);
void SubtractionTrick(GHistRowT self, GHistRowT sibling, GHistRowT parent,
int32_t n_threads);

uint32_t GetNumBins() const {
return nbins_;
}

private:
/*! \brief number of threads for parallel computation */
size_t nthread_ { 0 };
/*! \brief number of all bins over all features */
uint32_t nbins_ { 0 };
};
Expand Down
3 changes: 1 addition & 2 deletions src/common/partition_builder.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

/*!
* Copyright 2021 by Contributors
* \file row_set.h
Expand Down Expand Up @@ -236,7 +235,7 @@ class PartitionBuilder {
return blocks_offsets_[nid] + begin / BlockSize;
}

protected:
private:
struct BlockInfo{
size_t n_left;
size_t n_right;
Expand Down
2 changes: 1 addition & 1 deletion src/tree/hist/histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistogramBuilder {
hist_.Init(total_bins);
hist_local_worker_.Init(total_bins);
buffer_.Init(total_bins);
builder_ = common::GHistBuilder<GradientSumT>(n_threads, total_bins);
builder_ = common::GHistBuilder<GradientSumT>(total_bins);
is_distributed_ = is_distributed;
}

Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/tree/hist/test_evaluate_splits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ template <typename GradientSumT> void TestEvaluateSplits() {
std::iota(row_indices.begin(), row_indices.end(), 0);
row_set_collection.Init();

auto hist_builder = GHistBuilder<GradientSumT>(n_threads, gmat.cut.Ptrs().back());
auto hist_builder = GHistBuilder<GradientSumT>(gmat.cut.Ptrs().back());
hist.Init(gmat.cut.Ptrs().back());
hist.AddHistRow(0);
hist.AllocateAllData();
Expand Down

0 comments on commit 7140ba6

Please sign in to comment.