Require leaf statistics when expanding tree #4015

Merged (2 commits, Jan 18, 2019)
28 changes: 20 additions & 8 deletions include/xgboost/tree_model.h
@@ -303,14 +303,22 @@ class RegTree {
   }
 
   /**
-   * \brief Expands a leaf node into two additional leaf nodes
+   * \brief Expands a leaf node into two additional leaf nodes.
    *
-   * \param nid           The node index to expand.
-   * \param split_index   Feature index of the split.
-   * \param split_value   The split condition.
-   * \param default_left  True to default left.
+   * \param nid               The node index to expand.
+   * \param split_index       Feature index of the split.
+   * \param split_value       The split condition.
+   * \param default_left      True to default left.
+   * \param base_weight       The base weight, before learning rate.
+   * \param left_leaf_weight  The left leaf weight for prediction, modified by learning rate.
+   * \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
+   * \param loss_change       The loss change.
+   * \param sum_hess          The sum hess.
    */
-  void ExpandNode(int nid, unsigned split_index, bst_float split_value, bool default_left) {
+  void ExpandNode(int nid, unsigned split_index, bst_float split_value,
+                  bool default_left, bst_float base_weight,
+                  bst_float left_leaf_weight, bst_float right_leaf_weight,
+                  bst_float loss_change, float sum_hess) {
     int pleft = this->AllocNode();
     int pright = this->AllocNode();
     auto &node = nodes_[nid];
@@ -322,8 +330,12 @@ class RegTree {
     node.SetSplit(split_index, split_value,
                   default_left);
     // mark right child as 0, to indicate fresh leaf
-    nodes_[pleft].SetLeaf(0.0f, 0);
-    nodes_[pright].SetLeaf(0.0f, 0);
+    nodes_[pleft].SetLeaf(left_leaf_weight, 0);
+    nodes_[pright].SetLeaf(right_leaf_weight, 0);
+
+    this->Stat(nid).loss_chg = loss_change;
+    this->Stat(nid).base_weight = base_weight;
+    this->Stat(nid).sum_hess = sum_hess;
   }
 
   /*!
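With this change the caller, not the tree, is responsible for producing the node statistics: ExpandNode records loss_chg, base_weight, and sum_hess on the parent and seeds both children with usable leaf values in a single call. A minimal sketch of the new calling convention (the signature comes from the diff; all numeric values below are invented for illustration):

    RegTree tree;
    // Statistics the updater must now compute *before* expanding:
    bst_float base_weight = -0.4f;        // parent weight, before learning rate
    bst_float left_leaf_weight = -0.06f;  // child predictions, already scaled by eta
    bst_float right_leaf_weight = 0.05f;
    tree.ExpandNode(/*nid=*/0, /*split_index=*/2, /*split_value=*/0.5f,
                    /*default_left=*/true, base_weight, left_leaf_weight,
                    right_leaf_weight, /*loss_change=*/1.25f, /*sum_hess=*/42.0f);
    // Both fresh leaves now carry predictions immediately, and tree.Stat(0)
    // holds loss_chg, base_weight, and sum_hess without a separate pass.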
13 changes: 11 additions & 2 deletions src/tree/param.h
@@ -354,6 +354,8 @@ struct XGBOOST_ALIGNAS(16) GradStats {
   static const int kSimpleStats = 1;
   /*! \brief constructor, the object must be cleared during construction */
   explicit GradStats(const TrainParam& param) { this->Clear(); }
+  explicit GradStats(double sum_grad, double sum_hess)
+      : sum_grad(sum_grad), sum_hess(sum_hess) {}
 
   template <typename GpairT>
   XGBOOST_DEVICE explicit GradStats(const GpairT &sum)
@@ -490,8 +492,10 @@ struct SplitEntry {
   bst_float loss_chg{0.0f};
   /*! \brief split index */
   unsigned sindex{0};
-  /*! \brief split value */
   bst_float split_value{0.0f};
+  GradStats left_sum;
+  GradStats right_sum;
+
   /*! \brief constructor */
   SplitEntry() = default;
   /*!
@@ -521,6 +525,8 @@ struct SplitEntry {
       this->loss_chg = e.loss_chg;
       this->sindex = e.sindex;
       this->split_value = e.split_value;
+      this->left_sum = e.left_sum;
+      this->right_sum = e.right_sum;
       return true;
     } else {
       return false;
@@ -535,14 +541,17 @@ struct SplitEntry {
    * \return whether the proposed split is better and can replace current split
    */
   inline bool Update(bst_float new_loss_chg, unsigned split_index,
-                     bst_float new_split_value, bool default_left) {
+                     bst_float new_split_value, bool default_left,
+                     const GradStats &left_sum, const GradStats &right_sum) {
     if (this->NeedReplace(new_loss_chg, split_index)) {
       this->loss_chg = new_loss_chg;
      if (default_left) {
        split_index |= (1U << 31);
      }
      this->sindex = split_index;
      this->split_value = new_split_value;
+     this->left_sum = left_sum;
+     this->right_sum = right_sum;
      return true;
    } else {
      return false;
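A SplitEntry now travels with the gradient statistics of the two children it would create, so whoever applies the winning split can compute leaf weights without re-scanning the data. A short sketch of the extended Update contract (values invented; the two-argument GradStats constructor is the one added above):

    GradStats left_sum(/*sum_grad=*/-1.5, /*sum_hess=*/10.0);
    GradStats right_sum(/*sum_grad=*/0.7, /*sum_hess=*/8.0);
    SplitEntry best;
    if (best.Update(/*new_loss_chg=*/0.9f, /*split_index=*/3,
                    /*new_split_value=*/12.5f, /*default_left=*/false,
                    left_sum, right_sum)) {
      // best.left_sum / best.right_sum now describe the prospective children
      // and are copied along whenever a better candidate replaces this one.
    }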
58 changes: 39 additions & 19 deletions src/tree/updater_colmaker.cc
@@ -311,7 +311,7 @@ class ColMaker: public TreeUpdater {
             auto loss_chg = static_cast<bst_float>(
                 spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
                 snode_[nid].root_gain);
-            e.best.Update(loss_chg, fid, fsplit, false);
+            e.best.Update(loss_chg, fid, fsplit, false, e.stats, c);
           }
         }
         if (need_backward) {
@@ -322,7 +322,7 @@ class ColMaker: public TreeUpdater {
             auto loss_chg = static_cast<bst_float>(
                 spliteval_->ComputeSplitScore(nid, fid, tmp, c) -
                 snode_[nid].root_gain);
-            e.best.Update(loss_chg, fid, fsplit, true);
+            e.best.Update(loss_chg, fid, fsplit, true, tmp, c);
           }
         }
       }
@@ -335,7 +335,7 @@ class ColMaker: public TreeUpdater {
           auto loss_chg = static_cast<bst_float>(
               spliteval_->ComputeSplitScore(nid, fid, tmp, c) -
               snode_[nid].root_gain);
-          e.best.Update(loss_chg, fid, e.last_fvalue + kRtEps, true);
+          e.best.Update(loss_chg, fid, e.last_fvalue + kRtEps, true, tmp, c);
         }
       }
     }
@@ -368,7 +368,7 @@ class ColMaker: public TreeUpdater {
                 spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
                 snode_[nid].root_gain);
             e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f,
-                          false);
+                          false, e.stats, c);
           }
         }
         if (need_backward) {
@@ -379,7 +379,7 @@ class ColMaker: public TreeUpdater {
              auto loss_chg = static_cast<bst_float>(
                  spliteval_->ComputeSplitScore(nid, fid, c, cright) -
                  snode_[nid].root_gain);
-             e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f, true);
+             e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f, true, c, cright);
            }
          }
        }
@@ -410,13 +410,15 @@ class ColMaker: public TreeUpdater {
              loss_chg = static_cast<bst_float>(
                  spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
                  snode_[nid].root_gain);
+             e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
+                           d_step == -1, c, e.stats);
            } else {
              loss_chg = static_cast<bst_float>(
                  spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
                  snode_[nid].root_gain);
+             e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
+                           d_step == -1, e.stats, c);
            }
-           e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
-                         d_step == -1);
          }
        }
        // update the statistics
@@ -486,18 +488,21 @@ class ColMaker: public TreeUpdater {
          if (e.stats.sum_hess >= param_.min_child_weight &&
              c.sum_hess >= param_.min_child_weight) {
            bst_float loss_chg;
+           const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
+           const bst_float delta = d_step == +1 ? gap: -gap;
            if (d_step == -1) {
              loss_chg = static_cast<bst_float>(
                  spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
                  snode_[nid].root_gain);
+             e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, c,
+                           e.stats);
            } else {
              loss_chg = static_cast<bst_float>(
                  spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
                  snode_[nid].root_gain);
+             e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1,
+                           e.stats, c);
            }
-           const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
-           const bst_float delta = d_step == +1 ? gap: -gap;
-           e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1);
          }
        }
      }
@@ -545,12 +550,15 @@ class ColMaker: public TreeUpdater {
            loss_chg = static_cast<bst_float>(
                spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
                snode_[nid].root_gain);
+           e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
+                         d_step == -1, c, e.stats);
          } else {
            loss_chg = static_cast<bst_float>(
                spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
                snode_[nid].root_gain);
+           e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
+                         d_step == -1, e.stats, c);
          }
-         e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, d_step == -1);
        }
      }
      // update the statistics
@@ -565,18 +573,21 @@ class ColMaker: public TreeUpdater {
          if (e.stats.sum_hess >= param_.min_child_weight &&
              c.sum_hess >= param_.min_child_weight) {
            bst_float loss_chg;
+           GradStats left_sum;
+           GradStats right_sum;
            if (d_step == -1) {
-             loss_chg = static_cast<bst_float>(
-                 spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
-                 snode_[nid].root_gain);
+             left_sum = c;
+             right_sum = e.stats;
            } else {
-             loss_chg = static_cast<bst_float>(
-                 spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
-                 snode_[nid].root_gain);
+             left_sum = e.stats;
+             right_sum = c;
            }
+           loss_chg = static_cast<bst_float>(
+               spliteval_->ComputeSplitScore(nid, fid, left_sum, right_sum) -
+               snode_[nid].root_gain);
            const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
            const bst_float delta = d_step == +1 ? gap: -gap;
-           e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1);
+           e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, left_sum, right_sum);
          }
        }
      }
@@ -637,7 +648,16 @@ class ColMaker: public TreeUpdater {
       NodeEntry &e = snode_[nid];
       // now we know the solution in snode[nid], set split
       if (e.best.loss_chg > kRtEps) {
-        p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft());
+        bst_float left_leaf_weight =
+            spliteval_->ComputeWeight(nid, e.best.left_sum) *
+            param_.learning_rate;
+        bst_float right_leaf_weight =
+            spliteval_->ComputeWeight(nid, e.best.right_sum) *
+            param_.learning_rate;
+        p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
+                           e.best.DefaultLeft(), e.weight, left_leaf_weight,
+                           right_leaf_weight, e.best.loss_chg,
+                           e.stats.sum_hess);
       } else {
         (*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate);
       }
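The pattern repeated throughout this updater: the enumeration direction decides which accumulated sum is the left child. During backward enumeration (d_step == -1) the running sum e.stats covers the right child and the complement c the left, so the argument order flips. A hedged sketch of that convention, factored into a hypothetical helper (not part of the PR; SplitEntry and GradStats are the types above):

    // `scan` is the running accumulation, `rest` its complement w.r.t. the node.
    inline void UpdateBest(SplitEntry *best, bst_float loss_chg, unsigned fid,
                           bst_float split_pt, int d_step,
                           const GradStats &scan, const GradStats &rest) {
      if (d_step == -1) {
        // scanning right-to-left: scan is the right child
        best->Update(loss_chg, fid, split_pt, /*default_left=*/true, rest, scan);
      } else {
        // scanning left-to-right: scan is the left child
        best->Update(loss_chg, fid, split_pt, /*default_left=*/false, scan, rest);
      }
    }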
3 changes: 2 additions & 1 deletion src/tree/updater_gpu_common.cuh
@@ -296,7 +296,8 @@ inline void Dense2SparseTree(RegTree* p_tree,
   for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
     const DeviceNodeStats& n = h_nodes[gpu_nid];
     if (!n.IsUnused() && !n.IsLeaf()) {
-      tree.ExpandNode(nid, n.fidx, n.fvalue, n.dir == kLeftDir);
+      tree.ExpandNode(nid, n.fidx, n.fvalue, n.dir == kLeftDir, n.weight, 0.0f,
+                      0.0f, n.root_gain, n.sum_gradients.GetHess());
       tree.Stat(nid).loss_chg = n.root_gain;
       tree.Stat(nid).base_weight = n.weight;
       tree.Stat(nid).sum_hess = n.sum_gradients.GetHess();
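Note the two 0.0f child weights here: Dense2SparseTree converts an already-trained dense GPU tree, and the children's leaf values are filled in later, when the loop reaches those child entries (presumably via a leaf branch of this same loop, which the diff does not show), so the placeholders are never used for prediction. The explicit Stat(nid) assignments after the call are now redundant with what ExpandNode itself writes, but were left in place.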
45 changes: 19 additions & 26 deletions src/tree/updater_gpu_hist.cu
@@ -1182,42 +1182,35 @@ class GPUHistMakerSpecialised{
   }
 
   void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
-    // Add new leaves
     RegTree& tree = *p_tree;
-    tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue,
-                    candidate.split.dir == kLeftDir);
-    auto& parent = tree[candidate.nid];
-    tree.Stat(candidate.nid).loss_chg = candidate.split.loss_chg;
 
-    // Set up child constraints
-    node_value_constraints_.resize(tree.GetNodes().size());
     GradStats left_stats(param_);
     left_stats.Add(candidate.split.left_sum);
     GradStats right_stats(param_);
     right_stats.Add(candidate.split.right_sum);
-    node_value_constraints_[candidate.nid].SetChild(
-        param_, parent.SplitIndex(), left_stats, right_stats,
-        &node_value_constraints_[parent.LeftChild()],
-        &node_value_constraints_[parent.RightChild()]);
 
-    // Configure left child
+    GradStats parent_sum(param_);
+    parent_sum.Add(left_stats);
+    parent_sum.Add(right_stats);
+    node_value_constraints_.resize(tree.GetNodes().size());
+    auto base_weight = node_value_constraints_[candidate.nid].CalcWeight(param_, parent_sum);
     auto left_weight =
-        node_value_constraints_[parent.LeftChild()].CalcWeight(param_, left_stats);
-    tree[parent.LeftChild()].SetLeaf(left_weight * param_.learning_rate, 0);
-    tree.Stat(parent.LeftChild()).base_weight = left_weight;
-    tree.Stat(parent.LeftChild()).sum_hess = candidate.split.left_sum.GetHess();
-
-    // Configure right child
+        node_value_constraints_[candidate.nid].CalcWeight(param_, left_stats)*param_.learning_rate;
     auto right_weight =
-        node_value_constraints_[parent.RightChild()].CalcWeight(param_, right_stats);
-    tree[parent.RightChild()].SetLeaf(right_weight * param_.learning_rate, 0);
-    tree.Stat(parent.RightChild()).base_weight = right_weight;
-    tree.Stat(parent.RightChild()).sum_hess = candidate.split.right_sum.GetHess();
+        node_value_constraints_[candidate.nid].CalcWeight(param_, right_stats)*param_.learning_rate;
+    tree.ExpandNode(candidate.nid, candidate.split.findex,
+                    candidate.split.fvalue, candidate.split.dir == kLeftDir,
+                    base_weight, left_weight, right_weight,
+                    candidate.split.loss_chg, parent_sum.sum_hess);
+    // Set up child constraints
+    node_value_constraints_.resize(tree.GetNodes().size());
+    node_value_constraints_[candidate.nid].SetChild(
+        param_, tree[candidate.nid].SplitIndex(), left_stats, right_stats,
+        &node_value_constraints_[tree[candidate.nid].LeftChild()],
+        &node_value_constraints_[tree[candidate.nid].RightChild()]);
 
-    // Store sum gradients
     for (auto& shard : shards_) {
-      shard->node_sum_gradients[parent.LeftChild()] = candidate.split.left_sum;
-      shard->node_sum_gradients[parent.RightChild()] = candidate.split.right_sum;
+      shard->node_sum_gradients[tree[candidate.nid].LeftChild()] = candidate.split.left_sum;
+      shard->node_sum_gradients[tree[candidate.nid].RightChild()] = candidate.split.right_sum;
     }
   }
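The reordering in this hunk is forced by the new signature: the old code expanded first and then addressed the children through parent.LeftChild(), while the new code must derive base_weight and both child weights from the parent's accumulated sums before any child node exists. That is also why node_value_constraints_ is resized twice: once before ExpandNode so that candidate.nid is indexable when computing weights, and once after so that the freshly allocated children can receive their constraints in SetChild. Unlike the parent's base_weight, left_weight and right_weight already include the learning rate, matching the ExpandNode documentation above.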
17 changes: 14 additions & 3 deletions src/tree/updater_histmaker.cc
@@ -192,7 +192,8 @@ class HistMaker: public BaseMaker {
         c.SetSubstract(node_sum, s);
         if (c.sum_hess >= param_.min_child_weight) {
           double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
-          if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i], false)) {
+          if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i],
+                           false, s, c)) {
             *left_sum = s;
           }
         }
@@ -205,7 +206,7 @@ class HistMaker: public BaseMaker {
         c.SetSubstract(node_sum, s);
         if (c.sum_hess >= param_.min_child_weight) {
           double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
-          if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i-1], true)) {
+          if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i-1], true, c, s)) {
             *left_sum = c;
           }
         }
@@ -243,8 +244,18 @@ class HistMaker: public BaseMaker {
       p_tree->Stat(nid).loss_chg = best.loss_chg;
       // now we know the solution in snode[nid], set split
       if (best.loss_chg > kRtEps) {
+        bst_float base_weight = node_sum.CalcWeight(param_);
+        bst_float left_leaf_weight =
+            CalcWeight(param_, best.left_sum.sum_grad, best.left_sum.sum_hess) *
+            param_.learning_rate;
+        bst_float right_leaf_weight =
+            CalcWeight(param_, best.right_sum.sum_grad,
+                       best.right_sum.sum_hess) *
+            param_.learning_rate;
         p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
-                           best.DefaultLeft());
+                           best.DefaultLeft(), base_weight, left_leaf_weight,
+                           right_leaf_weight, best.loss_chg,
+                           node_sum.sum_hess);
         // right side sum
         TStats right_sum;
         right_sum.SetSubstract(node_sum, left_sum[wid]);
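For reference, the weight these CalcWeight calls produce is, setting aside the optional alpha and max_delta_step adjustments, the familiar XGBoost leaf value w = -sum_grad / (sum_hess + lambda). The per-leaf prediction stored in the tree is that weight multiplied by learning_rate, while base_weight is kept unscaled, as the ExpandNode documentation specifies.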
10 changes: 8 additions & 2 deletions src/tree/updater_quantile_hist.cc
@@ -429,8 +429,13 @@ void QuantileHistMaker::Builder::ApplySplit(int nid,
 
   /* 1. Create child nodes */
   NodeEntry& e = snode_[nid];
+  bst_float left_leaf_weight =
+      spliteval_->ComputeWeight(nid, e.best.left_sum) * param_.learning_rate;
+  bst_float right_leaf_weight =
+      spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
   p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
-                     e.best.DefaultLeft());
+                     e.best.DefaultLeft(), e.weight, left_leaf_weight,
+                     right_leaf_weight, e.best.loss_chg, e.stats.sum_hess);
 
   /* 2. Categorize member rows */
   const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
@@ -698,6 +703,7 @@ void QuantileHistMaker::Builder::EnumerateSplit(int d_step,
             spliteval_->ComputeSplitScore(nodeID, fid, e, c) -
             snode.root_gain);
         split_pt = cut_val[i];
+        best.Update(loss_chg, fid, split_pt, d_step == -1, e, c);
       } else {
         // backward enumeration: split at left bound of each bin
         loss_chg = static_cast<bst_float>(
@@ -709,8 +715,8 @@ void QuantileHistMaker::Builder::EnumerateSplit(int d_step,
         } else {
           split_pt = cut_val[i - 1];
         }
+        best.Update(loss_chg, fid, split_pt, d_step == -1, c, e);
      }
-     best.Update(loss_chg, fid, split_pt, d_step == -1);
    }
  }
}
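The final Update call could previously sit outside the if/else because its arguments did not depend on the enumeration direction; now that the left and right sums must be passed in the correct order, (e, c) for forward and (c, e) for backward, the call is duplicated into each branch. A compact way to read the whole PR: every site that evaluates a candidate split now forwards its child statistics into SplitEntry, and every site that applies one reads them back out through ComputeWeight, so the tree no longer needs a follow-up pass to fill in leaf values and node statistics.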