Skip to content

Commit

Permalink
Small cleanup to hist tree method. (#7735)
Browse files Browse the repository at this point in the history
* Remove special optimization using number of bins.
* Remove 1-based index for column sampling.
* Remove data layout.
* Unify update prediction cache.
  • Loading branch information
trivialfis authored Mar 19, 2022
1 parent 718472d commit 996cc70
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 205 deletions.
11 changes: 4 additions & 7 deletions src/common/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,8 @@ class ColumnSampler {
* \param colsample_bytree
* \param skip_index_0 (Optional) True to skip index 0.
*/
void Init(int64_t num_col, std::vector<float> feature_weights,
float colsample_bynode, float colsample_bylevel,
float colsample_bytree, bool skip_index_0 = false) {
void Init(int64_t num_col, std::vector<float> feature_weights, float colsample_bynode,
float colsample_bylevel, float colsample_bytree) {
feature_weights_ = std::move(feature_weights);
colsample_bylevel_ = colsample_bylevel;
colsample_bytree_ = colsample_bytree;
Expand All @@ -169,10 +168,8 @@ class ColumnSampler {
}
Reset();

int begin_idx = skip_index_0 ? 1 : 0;
feature_set_tree_->Resize(num_col - begin_idx);
std::iota(feature_set_tree_->HostVector().begin(),
feature_set_tree_->HostVector().end(), begin_idx);
feature_set_tree_->Resize(num_col);
std::iota(feature_set_tree_->HostVector().begin(), feature_set_tree_->HostVector().end(), 0);

feature_set_tree_ = ColSample(feature_set_tree_, colsample_bytree_);
}
Expand Down
34 changes: 16 additions & 18 deletions src/common/row_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ class RowSetCollection {
/*! \brief return corresponding element set given the node_id */
inline const Elem& operator[](unsigned node_id) const {
const Elem& e = elem_of_each_node_[node_id];
CHECK(e.begin != nullptr)
<< "access element that is not in the set";
return e;
}

Expand All @@ -75,14 +73,10 @@ class RowSetCollection {
CHECK_EQ(elem_of_each_node_.size(), 0U);

if (row_indices_.empty()) { // edge case: empty instance set
// assign arbitrary address here, to bypass nullptr check
// (nullptr usually indicates a nonexistent rowset, but we want to
// indicate a valid rowset that happens to have zero length and occupies
// the whole instance set)
// this is okay, as BuildHist will compute (end-begin) as the set size
const size_t* begin = reinterpret_cast<size_t*>(20);
const size_t* end = begin;
elem_of_each_node_.emplace_back(Elem(begin, end, 0));
constexpr size_t* kBegin = nullptr;
constexpr size_t* kEnd = nullptr;
static_assert(kEnd - kBegin == 0, "");
elem_of_each_node_.emplace_back(Elem(kBegin, kEnd, 0));
return;
}

Expand All @@ -93,15 +87,19 @@ class RowSetCollection {

std::vector<size_t>* Data() { return &row_indices_; }
// split rowset into two
inline void AddSplit(unsigned node_id,
unsigned left_node_id,
unsigned right_node_id,
size_t n_left,
size_t n_right) {
inline void AddSplit(unsigned node_id, unsigned left_node_id, unsigned right_node_id,
size_t n_left, size_t n_right) {
const Elem e = elem_of_each_node_[node_id];
CHECK(e.begin != nullptr);
size_t* all_begin = dmlc::BeginPtr(row_indices_);
size_t* begin = all_begin + (e.begin - all_begin);

size_t* all_begin{nullptr};
size_t* begin{nullptr};
if (e.begin == nullptr) {
CHECK_EQ(n_left, 0);
CHECK_EQ(n_right, 0);
} else {
all_begin = dmlc::BeginPtr(row_indices_);
begin = all_begin + (e.begin - all_begin);
}

CHECK_EQ(n_left + n_right, e.Size());
CHECK_LE(begin + n_left, e.end);
Expand Down
3 changes: 3 additions & 0 deletions src/common/threading_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ class MemStackAllocator {
throw std::bad_alloc{};
}
}
MemStackAllocator(size_t required_size, T init) : MemStackAllocator{required_size} {
std::fill_n(ptr_, required_size_, init);
}

~MemStackAllocator() {
if (required_size_ > MaxStackSize) {
Expand Down
49 changes: 42 additions & 7 deletions src/tree/hist/evaluate_splits.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,19 +363,54 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
// The column sampler must be constructed by caller since we need to preserve the rng
// for the entire training session.
explicit HistEvaluator(TrainParam const &param, MetaInfo const &info, int32_t n_threads,
std::shared_ptr<common::ColumnSampler> sampler, ObjInfo task,
bool skip_0_index = false)
std::shared_ptr<common::ColumnSampler> sampler, ObjInfo task)
: param_{param},
column_sampler_{std::move(sampler)},
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), GenericParameter::kCpuId},
n_threads_{n_threads},
task_{task} {
interaction_constraints_.Configure(param, info.num_col_);
column_sampler_->Init(info.num_col_, info.feature_weights.HostVector(),
param_.colsample_bynode, param_.colsample_bylevel,
param_.colsample_bytree, skip_0_index);
column_sampler_->Init(info.num_col_, info.feature_weights.HostVector(), param_.colsample_bynode,
param_.colsample_bylevel, param_.colsample_bytree);
}
};
} // namespace tree
} // namespace xgboost

/**
* \brief CPU implementation of update prediction cache, which calculates the leaf value
* for the last tree and accumulates it to prediction vector.
*
* \param p_last_tree The last tree being updated by tree updater
*/
template <typename Partitioner, typename GradientSumT, typename ExpandEntry>
void UpdatePredictionCacheImpl(GenericParameter const *ctx, RegTree const *p_last_tree,
std::vector<Partitioner> const &partitioner,
HistEvaluator<GradientSumT, ExpandEntry> const &hist_evaluator,
TrainParam const &param, linalg::VectorView<float> out_preds) {
CHECK_GT(out_preds.Size(), 0U);

CHECK(p_last_tree);
auto const &tree = *p_last_tree;
auto const &snode = hist_evaluator.Stats();
auto evaluator = hist_evaluator.Evaluator();
CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId);
size_t n_nodes = p_last_tree->GetNodes().size();
for (auto &part : partitioner) {
CHECK_EQ(part.Size(), n_nodes);
common::BlockedSpace2d space(
part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
common::ParallelFor2d(space, ctx->Threads(), [&](size_t nidx, common::Range1d r) {
if (!tree[nidx].IsDeleted() && tree[nidx].IsLeaf()) {
auto const &rowset = part[nidx];
auto const &stats = snode[nidx];
auto leaf_value =
evaluator.CalcWeight(nidx, param, GradStats{stats.stats}) * param.learning_rate;
for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
out_preds(*it) += leaf_value;
}
}
});
}
}
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
26 changes: 2 additions & 24 deletions src/tree/updater_approx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,34 +114,12 @@ class GloablApproxBuilder {
return nodes.front();
}

void UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) {
void UpdatePredictionCache(DMatrix const *data, linalg::VectorView<float> out_preds) const {
monitor_->Start(__func__);
// Caching prediction seems redundant for approx tree method, as sketching takes up
// majority of training time.
CHECK_EQ(out_preds.Size(), data->Info().num_row_);
CHECK(p_last_tree_);

size_t n_nodes = p_last_tree_->GetNodes().size();

auto evaluator = evaluator_.Evaluator();
auto const &tree = *p_last_tree_;
auto const &snode = evaluator_.Stats();
for (auto &part : partitioner_) {
CHECK_EQ(part.Size(), n_nodes);
common::BlockedSpace2d space(
part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
common::ParallelFor2d(space, ctx_->Threads(), [&](size_t nidx, common::Range1d r) {
if (tree[nidx].IsLeaf()) {
const auto rowset = part[nidx];
auto const &stats = snode.at(nidx);
auto leaf_value =
evaluator.CalcWeight(nidx, param_, GradStats{stats.stats}) * param_.learning_rate;
for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
out_preds(*it) += leaf_value;
}
}
});
}
UpdatePredictionCacheImpl(ctx_, p_last_tree_, partitioner_, evaluator_, param_, out_preds);
monitor_->Stop(__func__);
}

Expand Down
Loading

0 comments on commit 996cc70

Please sign in to comment.