diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h index edac92d06908..861cfae8adde 100644 --- a/src/common/column_matrix.h +++ b/src/common/column_matrix.h @@ -82,14 +82,16 @@ template class DenseColumn: public Column { public: DenseColumn(ColumnType type, common::Span index, - uint32_t index_base, - const std::vector::const_iterator missing_flags) + uint32_t index_base, const std::vector& missing_flags, + size_t feature_offset) : Column(type, index, index_base), - missing_flags_(missing_flags) {} - bool IsMissing(size_t idx) const { return missing_flags_[idx]; } + missing_flags_(missing_flags), + feature_offset_(feature_offset) {} + bool IsMissing(size_t idx) const { return missing_flags_[feature_offset_ + idx]; } private: /* flags for missing values in dense columns */ - std::vector::const_iterator missing_flags_; + const std::vector& missing_flags_; + size_t feature_offset_; }; /*! \brief a collection of columns, with support for construction from @@ -208,10 +210,8 @@ class ColumnMatrix { column_size }; std::unique_ptr > res; if (type_[fid] == ColumnType::kDenseColumn) { - std::vector::const_iterator column_iterator = missing_flags_.begin(); - advance(column_iterator, feature_offset); // increment iterator to right position res.reset(new DenseColumn(type_[fid], bin_index, index_base_[fid], - column_iterator)); + missing_flags_, feature_offset)); } else { res.reset(new SparseColumn(type_[fid], bin_index, index_base_[fid], {&row_ind_[feature_offset], column_size})); diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 942cda159bd4..38c58b495c3e 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -503,6 +503,15 @@ struct PinnedMemory { return xgboost::common::Span(static_cast(temp_storage), size); } + template + xgboost::common::Span GetSpan(size_t size, T init) { + auto result = this->GetSpan(size); + for (auto &e : result) { + e = init; + } + return result; + } + void Free() { if 
(temp_storage != nullptr) { safe_cuda(cudaFreeHost(temp_storage)); diff --git a/src/tree/gpu_hist/driver.cuh b/src/tree/gpu_hist/driver.cuh new file mode 100644 index 000000000000..675e877e1945 --- /dev/null +++ b/src/tree/gpu_hist/driver.cuh @@ -0,0 +1,120 @@ +/*! + * Copyright 2020 by XGBoost Contributors + */ +#ifndef DRIVER_CUH_ +#define DRIVER_CUH_ +#include +#include +#include "../param.h" +#include "evaluate_splits.cuh" + +namespace xgboost { +namespace tree { +struct ExpandEntry { + int nid; + int depth; + DeviceSplitCandidate split; + ExpandEntry() = default; + XGBOOST_DEVICE ExpandEntry(int nid, int depth, DeviceSplitCandidate split) + : nid(nid), depth(depth), split(std::move(split)) {} + bool IsValid(const TrainParam& param, int num_leaves) const { + if (split.loss_chg <= kRtEps) return false; + if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) { + return false; + } + if (split.loss_chg < param.min_split_loss) { + return false; + } + if (param.max_depth > 0 && depth == param.max_depth) { + return false; + } + if (param.max_leaves > 0 && num_leaves == param.max_leaves) { + return false; + } + return true; + } + + static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) { + if (param.max_depth > 0 && depth >= param.max_depth) return false; + if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false; + return true; + } + + friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) { + os << "ExpandEntry: \n"; + os << "nidx: " << e.nid << "\n"; + os << "depth: " << e.depth << "\n"; + os << "loss: " << e.split.loss_chg << "\n"; + os << "left_sum: " << e.split.left_sum << "\n"; + os << "right_sum: " << e.split.right_sum << "\n"; + return os; + } +}; + +inline bool DepthWise(const ExpandEntry& lhs, const ExpandEntry& rhs) { + return lhs.depth > rhs.depth; // favor small depth +} + +inline bool LossGuide(const ExpandEntry& lhs, const ExpandEntry& rhs) { + if (lhs.split.loss_chg == 
rhs.split.loss_chg) { + return lhs.nid > rhs.nid; // favor small nid + } else { + return lhs.split.loss_chg < rhs.split.loss_chg; // favor large loss_chg + } +} + +// Drives execution of tree building on device +class Driver { + using ExpandQueue = + std::priority_queue, + std::function>; + + public: + explicit Driver(TrainParam::TreeGrowPolicy policy) + : policy_(policy), + queue_(policy == TrainParam::kDepthWise ? DepthWise : LossGuide) {} + template + void Push(EntryIterT begin,EntryIterT end) { + for (auto it = begin; it != end; ++it) { + const ExpandEntry& e = *it; + if (e.split.loss_chg > kRtEps) { + queue_.push(e); + } + } + } + void Push(const std::vector &entries) { + this->Push(entries.begin(), entries.end()); + } + // Return the set of nodes to be expanded + // This set has no dependencies between entries so they may be expanded in + // parallel or asynchronously + std::vector Pop() { + if (queue_.empty()) return {}; + // Return a single entry for loss guided mode + if (policy_ == TrainParam::kLossGuide) { + ExpandEntry e = queue_.top(); + queue_.pop(); + return {e}; + } + // Return nodes on same level for depth wise + std::vector result; + ExpandEntry e = queue_.top(); + int level = e.depth; + while (e.depth == level && !queue_.empty()) { + queue_.pop(); + result.emplace_back(e); + if (!queue_.empty()) { + e = queue_.top(); + } + } + return result; + } + + private: + TrainParam::TreeGrowPolicy policy_; + ExpandQueue queue_; +}; +} // namespace tree +} // namespace xgboost + +#endif // DRIVER_CUH_ diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh index 7a68616748c2..fd42234fd345 100644 --- a/src/tree/gpu_hist/row_partitioner.cuh +++ b/src/tree/gpu_hist/row_partitioner.cuh @@ -61,6 +61,7 @@ class RowPartitioner { dh::caching_device_vector left_counts_; // Useful to keep a bunch of zeroed memory for sort position std::vector streams_; + dh::PinnedMemory pinned_; public: RowPartitioner(int device_idx, size_t 
num_rows); @@ -129,12 +130,12 @@ class RowPartitioner { d_position[idx] = new_position; }); // Overlap device to host memory copy (left_count) with sort - int64_t left_count; + int64_t &left_count = pinned_.GetSpan(1)[0]; dh::safe_cuda(cudaMemcpyAsync(&left_count, d_left_count, sizeof(int64_t), cudaMemcpyDeviceToHost, streams_[0])); - SortPositionAndCopy(segment, left_nidx, right_nidx, d_left_count, - streams_[1]); + SortPositionAndCopy(segment, left_nidx, right_nidx, d_left_count, streams_[1] + ); dh::safe_cuda(cudaStreamSynchronize(streams_[0])); CHECK_LE(left_count, segment.Size()); diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 69a3d79770ba..2de692c0a117 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -30,6 +30,7 @@ #include "gpu_hist/row_partitioner.cuh" #include "gpu_hist/histogram.cuh" #include "gpu_hist/evaluate_splits.cuh" +#include "gpu_hist/driver.cuh" namespace xgboost { namespace tree { @@ -57,58 +58,6 @@ struct GPUHistMakerTrainParam DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam); #endif // !defined(GTEST_TEST) -struct ExpandEntry { - int nid; - int depth; - DeviceSplitCandidate split; - uint64_t timestamp; - ExpandEntry() = default; - ExpandEntry(int nid, int depth, DeviceSplitCandidate split, - uint64_t timestamp) - : nid(nid), depth(depth), split(std::move(split)), timestamp(timestamp) {} - bool IsValid(const TrainParam& param, int num_leaves) const { - if (split.loss_chg <= kRtEps) return false; - if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) { - return false; - } - if (split.loss_chg < param.min_split_loss) { return false; } - if (param.max_depth > 0 && depth == param.max_depth) {return false; } - if (param.max_leaves > 0 && num_leaves == param.max_leaves) { return false; } - return true; - } - - static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) { - if (param.max_depth > 0 && depth >= param.max_depth) return false; - if 
(param.max_leaves > 0 && num_leaves >= param.max_leaves) return false; - return true; - } - - friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) { - os << "ExpandEntry: \n"; - os << "nidx: " << e.nid << "\n"; - os << "depth: " << e.depth << "\n"; - os << "loss: " << e.split.loss_chg << "\n"; - os << "left_sum: " << e.split.left_sum << "\n"; - os << "right_sum: " << e.split.right_sum << "\n"; - return os; - } -}; - -inline static bool DepthWise(const ExpandEntry& lhs, const ExpandEntry& rhs) { - if (lhs.depth == rhs.depth) { - return lhs.timestamp > rhs.timestamp; // favor small timestamp - } else { - return lhs.depth > rhs.depth; // favor small depth - } -} -inline static bool LossGuide(const ExpandEntry& lhs, const ExpandEntry& rhs) { - if (lhs.split.loss_chg == rhs.split.loss_chg) { - return lhs.timestamp > rhs.timestamp; // favor small timestamp - } else { - return lhs.split.loss_chg < rhs.split.loss_chg; // favor large loss_chg - } -} - /** * \struct DeviceHistogram * @@ -243,6 +192,8 @@ struct GPUHistMakerDevice { GradientSumT histogram_rounding; + dh::PinnedMemory pinned; + std::vector streams{}; common::Monitor monitor; @@ -250,11 +201,6 @@ struct GPUHistMakerDevice { common::ColumnSampler column_sampler; FeatureInteractionConstraintDevice interaction_constraints; - using ExpandQueue = - std::priority_queue, - std::function>; - std::unique_ptr qexpand; - std::unique_ptr sampler; GPUHistMakerDevice(int _device_id, @@ -314,11 +260,6 @@ struct GPUHistMakerDevice { // Note that the column sampler must be passed by value because it is not // thread safe void Reset(HostDeviceVector* dh_gpair, DMatrix* dmat, int64_t num_columns) { - if (param.grow_policy == TrainParam::kLossGuide) { - qexpand.reset(new ExpandQueue(LossGuide)); - } else { - qexpand.reset(new ExpandQueue(DepthWise)); - } this->column_sampler.Init(num_columns, param.colsample_bynode, param.colsample_bylevel, param.colsample_bytree); dh::safe_cuda(cudaSetDevice(device_id)); @@ 
-370,9 +311,9 @@ struct GPUHistMakerDevice { return result.front(); } - std::vector EvaluateLeftRightSplits( - ExpandEntry candidate, int left_nidx, int right_nidx, - const RegTree& tree) { + void EvaluateLeftRightSplits( + ExpandEntry candidate, int left_nidx, int right_nidx, const RegTree& tree, + common::Span pinned_candidates_out) { dh::TemporaryArray splits_out(2); GPUTrainingParam gpu_param(param); auto left_sampled_features = @@ -412,12 +353,19 @@ struct GPUHistMakerDevice { hist.GetNodeHistogram(right_nidx), node_value_constraints[right_nidx], dh::ToSpan(monotone_constraints)}; - EvaluateSplits(dh::ToSpan(splits_out), left, right); - std::vector result(2); - dh::safe_cuda(cudaMemcpy(result.data(), splits_out.data().get(), - sizeof(DeviceSplitCandidate) * splits_out.size(), - cudaMemcpyDeviceToHost)); - return result; + auto d_splits_out = dh::ToSpan(splits_out); + EvaluateSplits(d_splits_out, left, right); + dh::TemporaryArray entries(2); + auto d_entries = entries.data().get(); + dh::LaunchN(device_id, 1, [=] __device__(size_t idx) { + d_entries[0] = + ExpandEntry(left_nidx, candidate.depth + 1, d_splits_out[0]); + d_entries[1] = + ExpandEntry(right_nidx, candidate.depth + 1, d_splits_out[1]); + }); + dh::safe_cuda(cudaMemcpyAsync( + pinned_candidates_out.data(), entries.data().get(), + sizeof(ExpandEntry) * entries.size(), cudaMemcpyDeviceToHost)); } void BuildHist(int nidx) { @@ -637,7 +585,7 @@ struct GPUHistMakerDevice { tree[candidate.nid].RightChild()); } - void InitRoot(RegTree* p_tree, dh::AllReducer* reducer) { + ExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) { constexpr bst_node_t kRootNIdx = 0; dh::XGBCachingDeviceAllocator alloc; GradientPair root_sum = thrust::reduce( @@ -662,61 +610,66 @@ struct GPUHistMakerDevice { // Generate first split auto split = this->EvaluateRootSplit(root_sum); - qexpand->push( - ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split, 0)); + return ExpandEntry(kRootNIdx, 
p_tree->GetDepth(kRootNIdx), split); } void UpdateTree(HostDeviceVector* gpair_all, DMatrix* p_fmat, RegTree* p_tree, dh::AllReducer* reducer) { auto& tree = *p_tree; + Driver driver(static_cast(param.grow_policy)); monitor.Start("Reset"); this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_); monitor.Stop("Reset"); monitor.Start("InitRoot"); - this->InitRoot(p_tree, reducer); + driver.Push({ this->InitRoot(p_tree, reducer) }); monitor.Stop("InitRoot"); - auto timestamp = qexpand->size(); auto num_leaves = 1; - while (!qexpand->empty()) { - ExpandEntry candidate = qexpand->top(); - qexpand->pop(); - if (!candidate.IsValid(param, num_leaves)) { - continue; - } - this->ApplySplit(candidate, p_tree); - - num_leaves++; - - int left_child_nidx = tree[candidate.nid].LeftChild(); - int right_child_nidx = tree[candidate.nid].RightChild(); - // Only create child entries if needed - if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), - num_leaves)) { - monitor.Start("UpdatePosition"); - this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]); - monitor.Stop("UpdatePosition"); - - monitor.Start("BuildHist"); - this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer); - monitor.Stop("BuildHist"); - - monitor.Start("EvaluateSplits"); - auto splits = this->EvaluateLeftRightSplits(candidate, left_child_nidx, - right_child_nidx, - *p_tree); - monitor.Stop("EvaluateSplits"); - - qexpand->push(ExpandEntry(left_child_nidx, - tree.GetDepth(left_child_nidx), splits.at(0), - timestamp++)); - qexpand->push(ExpandEntry(right_child_nidx, - tree.GetDepth(right_child_nidx), - splits.at(1), timestamp++)); + // The set of leaves that can be expanded asynchronously + auto expand_set = driver.Pop(); + while (!expand_set.empty()) { + auto new_candidates = + pinned.GetSpan(expand_set.size() * 2, ExpandEntry()); + + for (auto i = 0ull; i < expand_set.size(); i++) { + auto candidate = expand_set.at(i); + if (!candidate.IsValid(param, num_leaves)) { + 
continue; + } + this->ApplySplit(candidate, p_tree); + + num_leaves++; + + int left_child_nidx = tree[candidate.nid].LeftChild(); + int right_child_nidx = tree[candidate.nid].RightChild(); + // Only create child entries if needed + if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), + num_leaves)) { + monitor.Start("UpdatePosition"); + this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]); + monitor.Stop("UpdatePosition"); + + monitor.Start("BuildHist"); + this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer); + monitor.Stop("BuildHist"); + + monitor.Start("EvaluateSplits"); + this->EvaluateLeftRightSplits(candidate, left_child_nidx, + right_child_nidx, *p_tree, + new_candidates.subspan(i * 2, 2)); + monitor.Stop("EvaluateSplits"); + } else { + // Set default + new_candidates[i * 2] = ExpandEntry(); + new_candidates[i * 2 + 1] = ExpandEntry(); + } } + dh::safe_cuda(cudaDeviceSynchronize()); + driver.Push(new_candidates.begin(), new_candidates.end()); + expand_set = driver.Pop(); } monitor.Start("FinalisePosition"); diff --git a/tests/cpp/test_serialization.cc b/tests/cpp/test_serialization.cc index 80abaf8422f2..e6bf1fef854a 100644 --- a/tests/cpp/test_serialization.cc +++ b/tests/cpp/test_serialization.cc @@ -264,7 +264,7 @@ TEST_F(SerializationTest, CPUCoordDescent) { } #if defined(XGBOOST_USE_CUDA) -TEST_F(SerializationTest, GPUHist) { +TEST_F(SerializationTest, GpuHist) { TestLearnerSerialization({{"booster", "gbtree"}, {"seed", "0"}, {"enable_experimental_json_serialization", "1"}, @@ -441,7 +441,7 @@ TEST_F(LogitSerializationTest, CPUCoordDescent) { } #if defined(XGBOOST_USE_CUDA) -TEST_F(LogitSerializationTest, GPUHist) { +TEST_F(LogitSerializationTest, GpuHist) { TestLearnerSerialization({{"booster", "gbtree"}, {"objective", "binary:logistic"}, {"seed", "0"}, @@ -596,7 +596,7 @@ TEST_F(MultiClassesSerializationTest, CPUCoordDescent) { } #if defined(XGBOOST_USE_CUDA) -TEST_F(MultiClassesSerializationTest, 
GPUHist) { +TEST_F(MultiClassesSerializationTest, GpuHist) { TestLearnerSerialization({{"booster", "gbtree"}, {"num_class", std::to_string(kClasses)}, {"seed", "0"}, diff --git a/tests/cpp/tree/gpu_hist/test_driver.cu b/tests/cpp/tree/gpu_hist/test_driver.cu new file mode 100644 index 000000000000..25c1c11cb9ab --- /dev/null +++ b/tests/cpp/tree/gpu_hist/test_driver.cu @@ -0,0 +1,59 @@ +#include +#include "../../../../src/tree/gpu_hist/driver.cuh" + +namespace xgboost { +namespace tree { + +TEST(GpuHist, DriverDepthWise) { + Driver driver(TrainParam::kDepthWise); + EXPECT_TRUE(driver.Pop().empty()); + DeviceSplitCandidate split; + split.loss_chg = 1.0f; + ExpandEntry root(0, 0, split); + driver.Push({root}); + EXPECT_EQ(driver.Pop().front().nid, 0); + driver.Push({ExpandEntry{1, 1, split}}); + driver.Push({ExpandEntry{2, 1, split}}); + driver.Push({ExpandEntry{3, 2, split}}); + // Should return entries from level 1 + auto res = driver.Pop(); + EXPECT_EQ(res.size(), 2); + for (auto &e : res) { + EXPECT_EQ(e.depth, 1); + } + res = driver.Pop(); + EXPECT_EQ(res[0].depth, 2); + EXPECT_TRUE(driver.Pop().empty()); +} + +TEST(GpuHist, DriverLossGuided) { + DeviceSplitCandidate high_gain; + high_gain.loss_chg = 5.0f; + DeviceSplitCandidate low_gain; + low_gain.loss_chg = 1.0f; + + Driver driver(TrainParam::kLossGuide); + EXPECT_TRUE(driver.Pop().empty()); + ExpandEntry root(0, 0, high_gain); + driver.Push({root}); + EXPECT_EQ(driver.Pop().front().nid, 0); + // Select high gain first + driver.Push({ExpandEntry{1, 1, low_gain}}); + driver.Push({ExpandEntry{2, 2, high_gain}}); + auto res = driver.Pop(); + EXPECT_EQ(res.size(), 1); + EXPECT_EQ(res[0].nid, 2); + res = driver.Pop(); + EXPECT_EQ(res.size(), 1); + EXPECT_EQ(res[0].nid, 1); + + // If equal gain, use nid + driver.Push({ExpandEntry{2, 1, low_gain}}); + driver.Push({ExpandEntry{1, 1, low_gain}}); + res = driver.Pop(); + EXPECT_EQ(res[0].nid, 1); + res = driver.Pop(); + EXPECT_EQ(res[0].nid, 2); +} +} // namespace tree 
+} // namespace xgboost diff --git a/tests/cpp/tree/test_tree_stat.cc b/tests/cpp/tree/test_tree_stat.cc index eb727a988756..eb8a7c5d910c 100644 --- a/tests/cpp/tree/test_tree_stat.cc +++ b/tests/cpp/tree/test_tree_stat.cc @@ -40,7 +40,7 @@ class UpdaterTreeStatTest : public ::testing::Test { }; #if defined(XGBOOST_USE_CUDA) -TEST_F(UpdaterTreeStatTest, GPUHist) { +TEST_F(UpdaterTreeStatTest, GpuHist) { this->RunTest("grow_gpu_hist"); } #endif // defined(XGBOOST_USE_CUDA)