Skip to content

Commit

Permalink
Cache left and right gradient sums
Browse files Browse the repository at this point in the history
  • Loading branch information
RAMitchell committed Dec 21, 2018
1 parent f75a21a commit 9537a08
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 27 deletions.
13 changes: 11 additions & 2 deletions src/tree/param.h
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,8 @@ struct XGBOOST_ALIGNAS(16) GradStats {
static const int kSimpleStats = 1;
/*! \brief constructor, the object must be cleared during construction */
explicit GradStats(const TrainParam& param) { this->Clear(); }
explicit GradStats(double sum_grad, double sum_hess)
: sum_grad(sum_grad), sum_hess(sum_hess) {}

template <typename GpairT>
XGBOOST_DEVICE explicit GradStats(const GpairT &sum)
Expand Down Expand Up @@ -490,8 +492,10 @@ struct SplitEntry {
bst_float loss_chg{0.0f};
/*! \brief split index */
unsigned sindex{0};
/*! \brief split value */
bst_float split_value{0.0f};
GradStats left_sum;
GradStats right_sum;

/*! \brief constructor */
SplitEntry() = default;
/*!
Expand Down Expand Up @@ -521,6 +525,8 @@ struct SplitEntry {
this->loss_chg = e.loss_chg;
this->sindex = e.sindex;
this->split_value = e.split_value;
this->left_sum = e.left_sum;
this->right_sum = e.right_sum;
return true;
} else {
return false;
Expand All @@ -535,14 +541,17 @@ struct SplitEntry {
* \return whether the proposed split is better and can replace current split
*/
inline bool Update(bst_float new_loss_chg, unsigned split_index,
bst_float new_split_value, bool default_left) {
bst_float new_split_value, bool default_left,
const GradStats &left_sum, const GradStats &right_sum) {
if (this->NeedReplace(new_loss_chg, split_index)) {
this->loss_chg = new_loss_chg;
if (default_left) {
split_index |= (1U << 31);
}
this->sindex = split_index;
this->split_value = new_split_value;
this->left_sum = left_sum;
this->right_sum = right_sum;
return true;
} else {
return false;
Expand Down
33 changes: 18 additions & 15 deletions src/tree/updater_colmaker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ class ColMaker: public TreeUpdater {
auto loss_chg = static_cast<bst_float>(
spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
snode_[nid].root_gain);
e.best.Update(loss_chg, fid, fsplit, false);
e.best.Update(loss_chg, fid, fsplit, false, e.stats, c);
}
}
if (need_backward) {
Expand All @@ -322,7 +322,7 @@ class ColMaker: public TreeUpdater {
auto loss_chg = static_cast<bst_float>(
spliteval_->ComputeSplitScore(nid, fid, tmp, c) -
snode_[nid].root_gain);
e.best.Update(loss_chg, fid, fsplit, true);
e.best.Update(loss_chg, fid, fsplit, true, tmp, c);
}
}
}
Expand All @@ -335,7 +335,7 @@ class ColMaker: public TreeUpdater {
auto loss_chg = static_cast<bst_float>(
spliteval_->ComputeSplitScore(nid, fid, tmp, c) -
snode_[nid].root_gain);
e.best.Update(loss_chg, fid, e.last_fvalue + kRtEps, true);
e.best.Update(loss_chg, fid, e.last_fvalue + kRtEps, true, tmp, c);
}
}
}
Expand Down Expand Up @@ -368,7 +368,7 @@ class ColMaker: public TreeUpdater {
spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
snode_[nid].root_gain);
e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f,
false);
false, e.stats, c);
}
}
if (need_backward) {
Expand All @@ -379,7 +379,7 @@ class ColMaker: public TreeUpdater {
auto loss_chg = static_cast<bst_float>(
spliteval_->ComputeSplitScore(nid, fid, c, cright) -
snode_[nid].root_gain);
e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f, true);
e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f, true, c, cright);
}
}
}
Expand Down Expand Up @@ -416,7 +416,7 @@ class ColMaker: public TreeUpdater {
snode_[nid].root_gain);
}
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
d_step == -1);
d_step == -1, e.stats, c);
}
}
// update the statistics
Expand Down Expand Up @@ -497,7 +497,7 @@ class ColMaker: public TreeUpdater {
}
const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
const bst_float delta = d_step == +1 ? gap: -gap;
e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1);
e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, e.stats, c);
}
}
}
Expand Down Expand Up @@ -550,7 +550,7 @@ class ColMaker: public TreeUpdater {
spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
snode_[nid].root_gain);
}
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, d_step == -1);
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, d_step == -1, e.stats, c);
}
}
// update the statistics
Expand All @@ -565,18 +565,21 @@ class ColMaker: public TreeUpdater {
if (e.stats.sum_hess >= param_.min_child_weight &&
c.sum_hess >= param_.min_child_weight) {
bst_float loss_chg;
GradStats left_sum;
GradStats right_sum;
if (d_step == -1) {
loss_chg = static_cast<bst_float>(
spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
snode_[nid].root_gain);
left_sum = c;
right_sum = e.stats;
} else {
loss_chg = static_cast<bst_float>(
spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
snode_[nid].root_gain);
left_sum = e.stats;
right_sum = c;
}
loss_chg = static_cast<bst_float>(
spliteval_->ComputeSplitScore(nid, fid, left_sum, right_sum) -
snode_[nid].root_gain);
const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
const bst_float delta = d_step == +1 ? gap: -gap;
e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1);
e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, left_sum, right_sum);
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/tree/updater_histmaker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ class HistMaker: public BaseMaker {
c.SetSubstract(node_sum, s);
if (c.sum_hess >= param_.min_child_weight) {
double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i], false)) {
if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i],
false, s, c)) {
*left_sum = s;
}
}
Expand All @@ -205,7 +206,7 @@ class HistMaker: public BaseMaker {
c.SetSubstract(node_sum, s);
if (c.sum_hess >= param_.min_child_weight) {
double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i-1], true)) {
if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i-1], true, s, c)) {
*left_sum = c;
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,7 @@ void QuantileHistMaker::Builder::EnumerateSplit(int d_step,
spliteval_->ComputeSplitScore(nodeID, fid, e, c) -
snode.root_gain);
split_pt = cut_val[i];
best.Update(loss_chg, fid, split_pt, d_step == -1, e, c);
} else {
// backward enumeration: split at left bound of each bin
loss_chg = static_cast<bst_float>(
Expand All @@ -709,8 +710,8 @@ void QuantileHistMaker::Builder::EnumerateSplit(int d_step,
} else {
split_pt = cut_val[i - 1];
}
best.Update(loss_chg, fid, split_pt, d_step == -1, c, e);
}
best.Update(loss_chg, fid, split_pt, d_step == -1);
}
}
}
Expand Down
14 changes: 10 additions & 4 deletions src/tree/updater_skmaker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -336,15 +336,19 @@ class SketchMaker: public BaseMaker {
if (s.sum_hess >= param_.min_child_weight &&
c.sum_hess >= param_.min_child_weight) {
double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
best->Update(static_cast<bst_float>(loss_chg), fid, fsplits[i], false);
best->Update(static_cast<bst_float>(loss_chg), fid, fsplits[i], false,
GradStats(s.pos_grad - s.neg_grad , s.sum_hess),
GradStats(c.pos_grad - c.neg_grad, c.sum_hess));
}
// backward
c.SetSubstract(feat_sum, s);
s.SetSubstract(node_sum, c);
if (s.sum_hess >= param_.min_child_weight &&
c.sum_hess >= param_.min_child_weight) {
double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
best->Update(static_cast<bst_float>(loss_chg), fid, fsplits[i], true);
best->Update(static_cast<bst_float>(loss_chg), fid, fsplits[i], true,
GradStats(s.pos_grad - s.neg_grad, s.sum_hess),
GradStats(c.pos_grad - c.neg_grad, c.sum_hess));
}
}
{
Expand All @@ -355,8 +359,10 @@ class SketchMaker: public BaseMaker {
c.sum_hess >= param_.min_child_weight) {
bst_float cpt = fsplits.back();
double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
best->Update(static_cast<bst_float>(loss_chg),
fid, cpt + std::abs(cpt) + 1.0f, false);
best->Update(static_cast<bst_float>(loss_chg), fid,
cpt + std::abs(cpt) + 1.0f, false,
GradStats(s.pos_grad - s.neg_grad, s.sum_hess),
GradStats(c.pos_grad - c.neg_grad, c.sum_hess));
}
}
}
Expand Down
9 changes: 6 additions & 3 deletions tests/cpp/tree/test_param.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,15 @@ TEST(Param, SplitEntry) {

xgboost::tree::SplitEntry se2;
EXPECT_FALSE(se1.Update(se2));
EXPECT_FALSE(se2.Update(-1, 100, 0, true));
ASSERT_TRUE(se2.Update(1, 100, 0, true));
EXPECT_FALSE(se2.Update(-1, 100, 0, true, xgboost::tree::GradStats(),
xgboost::tree::GradStats()));
ASSERT_TRUE(se2.Update(1, 100, 0, true, xgboost::tree::GradStats(),
xgboost::tree::GradStats()));
ASSERT_TRUE(se1.Update(se2));

xgboost::tree::SplitEntry se3;
se3.Update(2, 101, 0, false);
se3.Update(2, 101, 0, false, xgboost::tree::GradStats(),
xgboost::tree::GradStats());
xgboost::tree::SplitEntry::Reduce(se2, se3);
EXPECT_EQ(se2.SplitIndex(), 101);
EXPECT_FALSE(se2.DefaultLeft());
Expand Down

0 comments on commit 9537a08

Please sign in to comment.