From 7140ba617a99af62a7996449b2bd4ed646a1f00a Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 20 Oct 2021 18:51:03 +0800 Subject: [PATCH] Pass n_threads. --- src/common/hist_util.cc | 13 ++++++++----- src/common/hist_util.h | 14 +++++--------- src/common/partition_builder.h | 3 +-- src/tree/hist/histogram.h | 2 +- tests/cpp/tree/hist/test_evaluate_splits.cc | 2 +- 5 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index 6431487bb4b6..71cc76611581 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -259,10 +259,11 @@ GHistBuilder::BuildHist(const std::vector &gpair, const GHistIndexMatrix &gmat, GHistRow hist) const; -template +template void GHistBuilder::SubtractionTrick(GHistRowT self, GHistRowT sibling, - GHistRowT parent) { + GHistRowT parent, + int32_t n_threads) { const size_t size = self.size(); CHECK_EQ(sibling.size(), size); CHECK_EQ(parent.size(), size); @@ -270,7 +271,7 @@ void GHistBuilder::SubtractionTrick(GHistRowT self, const size_t block_size = 1024; // aproximatly 1024 values per block size_t n_blocks = size/block_size + !!(size%block_size); - ParallelFor(omp_ulong(n_blocks), [&](omp_ulong iblock) { + ParallelFor(omp_ulong(n_blocks), n_threads, [&](omp_ulong iblock) { const size_t ibegin = iblock * block_size; const size_t iend = (((iblock + 1) * block_size > size) ? size : ibegin + block_size); @@ -280,11 +281,13 @@ void GHistBuilder::SubtractionTrick(GHistRowT self, template void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, - GHistRow parent); + GHistRow parent, + int32_t n_threads); template void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, - GHistRow parent); + GHistRow parent, + int32_t n_threads); } // namespace common } // namespace xgboost diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 1fa3a2240258..3cc130e8135d 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -567,26 +567,22 @@ class GHistBuilder { using GHistRowT = GHistRow; GHistBuilder() = default; - GHistBuilder(size_t nthread, uint32_t nbins) : nthread_{nthread}, nbins_{nbins} {} + explicit GHistBuilder(uint32_t nbins): nbins_{nbins} {} // construct a histogram via histogram aggregation template - void BuildHist(const std::vector& gpair, + void BuildHist(const std::vector &gpair, const RowSetCollection::Elem row_indices, - const GHistIndexMatrix& gmat, - GHistRowT hist) const; + const GHistIndexMatrix &gmat, GHistRowT hist) const; // construct a histogram via subtraction trick - void SubtractionTrick(GHistRowT self, - GHistRowT sibling, - GHistRowT parent); + void SubtractionTrick(GHistRowT self, GHistRowT sibling, GHistRowT parent, + int32_t n_threads); uint32_t GetNumBins() const { return nbins_; } private: - /*! \brief number of threads for parallel computation */ - size_t nthread_ { 0 }; /*! \brief number of all bins over all features */ uint32_t nbins_ { 0 }; }; diff --git a/src/common/partition_builder.h b/src/common/partition_builder.h index 5ffe34988968..0a59b6522428 100644 --- a/src/common/partition_builder.h +++ b/src/common/partition_builder.h @@ -1,4 +1,3 @@ - /*! * Copyright 2021 by Contributors * \file row_set.h @@ -236,7 +235,7 @@ class PartitionBuilder { return blocks_offsets_[nid] + begin / BlockSize; } - protected: + private: struct BlockInfo{ size_t n_left; size_t n_right; diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index 2352310d9b9e..0ff16fdc60fe 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -51,7 +51,7 @@ template class HistogramBuilder { hist_.Init(total_bins); hist_local_worker_.Init(total_bins); buffer_.Init(total_bins); - builder_ = common::GHistBuilder(n_threads, total_bins); + builder_ = common::GHistBuilder(total_bins); is_distributed_ = is_distributed; } diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc index cb0171269305..fe7b59ccee46 100644 --- a/tests/cpp/tree/hist/test_evaluate_splits.cc +++ b/tests/cpp/tree/hist/test_evaluate_splits.cc @@ -39,7 +39,7 @@ template void TestEvaluateSplits() { std::iota(row_indices.begin(), row_indices.end(), 0); row_set_collection.Init(); - auto hist_builder = GHistBuilder(n_threads, gmat.cut.Ptrs().back()); + auto hist_builder = GHistBuilder(gmat.cut.Ptrs().back()); hist.Init(gmat.cut.Ptrs().back()); hist.AddHistRow(0); hist.AllocateAllData();