Skip to content

Commit

Permalink
Requires setting leaf stat when expanding tree.
Browse files Browse the repository at this point in the history
Partly extracted from dmlc#5460.

* Fix GPU Hist feature importance.
  • Loading branch information
trivialfis committed Apr 9, 2020
1 parent ad826e9 commit 83b8e68
Show file tree
Hide file tree
Showing 11 changed files with 175 additions and 48 deletions.
46 changes: 41 additions & 5 deletions include/xgboost/tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <cstring>
#include <algorithm>
#include <tuple>
#include <stack>

namespace xgboost {

Expand Down Expand Up @@ -88,6 +89,10 @@ struct RTreeNodeStat {
bst_float base_weight;
/*! \brief number of child that is leaf node known up to now */
int leaf_child_cnt {0};

RTreeNodeStat() = default;
RTreeNodeStat(float loss_chg, float sum_hess, float weight) :
loss_chg{loss_chg}, sum_hess{sum_hess}, base_weight{weight} {}
bool operator==(const RTreeNodeStat& b) const {
return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
base_weight == b.base_weight && leaf_child_cnt == b.leaf_child_cnt;
Expand All @@ -101,8 +106,9 @@ struct RTreeNodeStat {
class RegTree : public Model {
public:
using SplitCondT = bst_float;
static constexpr int32_t kInvalidNodeId {-1};
static constexpr bst_node_t kInvalidNodeId {-1};
static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
static constexpr bst_node_t kRoot { 0 };

/*! \brief tree node */
class Node {
Expand Down Expand Up @@ -321,6 +327,31 @@ class RegTree : public Model {
return nodes_ == b.nodes_ && stats_ == b.stats_ &&
deleted_nodes_ == b.deleted_nodes_ && param == b.param;
}
/* \brief Iterate through all nodes in this tree.
*
* \param Function that accepts a node index, and returns false when iteration should
* stop, otherwise returns true.
*/
template <typename Func> void WalkTree(Func func) const {
std::stack<bst_node_t> nodes;
nodes.push(kRoot);
auto &self = *this;
while (!nodes.empty()) {
auto nidx = nodes.top();
nodes.pop();
if (!func(nidx)) {
return;
}
auto left = self[nidx].LeftChild();
auto right = self[nidx].RightChild();
if (left != RegTree::kInvalidNodeId) {
nodes.push(left);
}
if (right != RegTree::kInvalidNodeId) {
nodes.push(right);
}
}
}
/*!
* \brief Compares whether 2 trees are equal from a user's perspective. The equality
* compares only non-deleted nodes.
Expand All @@ -347,7 +378,8 @@ class RegTree : public Model {
void ExpandNode(int nid, unsigned split_index, bst_float split_value,
bool default_left, bst_float base_weight,
bst_float left_leaf_weight, bst_float right_leaf_weight,
bst_float loss_change, float sum_hess,
bst_float loss_change, float sum_hess, float left_sum,
float right_sum,
bst_node_t leaf_right_child = kInvalidNodeId) {
int pleft = this->AllocNode();
int pright = this->AllocNode();
Expand All @@ -363,9 +395,9 @@ class RegTree : public Model {
nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child);
nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child);

this->Stat(nid).loss_chg = loss_change;
this->Stat(nid).base_weight = base_weight;
this->Stat(nid).sum_hess = sum_hess;
this->Stat(nid) = {loss_change, sum_hess, base_weight};
this->Stat(pleft) = {0.0f, left_sum, left_leaf_weight};
this->Stat(pright) = {0.0f, right_sum, right_leaf_weight};
}

/*!
Expand Down Expand Up @@ -402,6 +434,10 @@ class RegTree : public Model {
return param.num_nodes - 1 - param.num_deleted;
}

/* \brief Count number of leaves in tree. */
bst_node_t GetNumLeaves() const;
bst_node_t GetNumSplitNodes() const;

/*!
* \brief dense feature vector that can be taken by RegTree
* and can be construct from sparse feature vector.
Expand Down
52 changes: 34 additions & 18 deletions src/tree/tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,8 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
return new GraphvizGenerator(fmap, attrs, with_stats);
});

constexpr bst_node_t RegTree::kRoot;

std::string RegTree::DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const {
Expand All @@ -623,26 +625,40 @@ bool RegTree::Equal(const RegTree& b) const {
if (NumExtraNodes() != b.NumExtraNodes()) {
return false;
}

std::stack<bst_node_t> nodes;
nodes.push(0);
auto& self = *this;
while (!nodes.empty()) {
auto nid = nodes.top();
nodes.pop();
if (!(self.nodes_.at(nid) == b.nodes_.at(nid))) {
auto const& self = *this;
bool ret { true };
this->WalkTree([&self, &b, &ret](bst_node_t nidx) {
if (!(self.nodes_.at(nidx) == b.nodes_.at(nidx))) {
ret = false;
return false;
}
auto left = self[nid].LeftChild();
auto right = self[nid].RightChild();
if (left != RegTree::kInvalidNodeId) {
nodes.push(left);
}
if (right != RegTree::kInvalidNodeId) {
nodes.push(right);
}
}
return true;
return true;
});
return ret;
}

bst_node_t RegTree::GetNumLeaves() const {
bst_node_t leaves { 0 };
auto const& self = *this;
this->WalkTree([&leaves, &self](bst_node_t nidx) {
if (self[nidx].IsLeaf()) {
leaves++;
}
return true;
});
return leaves;
}

bst_node_t RegTree::GetNumSplitNodes() const {
bst_node_t splits { 0 };
auto const& self = *this;
this->WalkTree([&splits, &self](bst_node_t nidx) {
if (!self[nidx].IsLeaf()) {
splits++;
}
return true;
});
return splits;
}

void RegTree::Load(dmlc::Stream* fi) {
Expand Down
4 changes: 3 additions & 1 deletion src/tree/updater_colmaker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,9 @@ class ColMaker: public TreeUpdater {
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
e.best.DefaultLeft(), e.weight, left_leaf_weight,
right_leaf_weight, e.best.loss_chg,
e.stats.sum_hess, 0);
e.stats.sum_hess,
e.best.left_sum.GetHess(), e.best.right_sum.GetHess(),
0);
} else {
(*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate);
}
Expand Down
3 changes: 2 additions & 1 deletion src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,8 @@ struct GPUHistMakerDevice {
tree.ExpandNode(candidate.nid, candidate.split.findex,
candidate.split.fvalue, candidate.split.dir == kLeftDir,
base_weight, left_weight, right_weight,
candidate.split.loss_chg, parent_sum.sum_hess);
candidate.split.loss_chg, parent_sum.sum_hess,
left_stats.GetHess(), right_stats.GetHess());
// Set up child constraints
node_value_constraints.resize(tree.GetNodes().size());
node_value_constraints[candidate.nid].SetChild(
Expand Down
3 changes: 2 additions & 1 deletion src/tree/updater_histmaker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ class HistMaker: public BaseMaker {
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
best.DefaultLeft(), base_weight, left_leaf_weight,
right_leaf_weight, best.loss_chg,
node_sum.sum_hess);
node_sum.sum_hess,
best.left_sum.GetHess(), best.right_sum.GetHess());
GradStats right_sum;
right_sum.SetSubstract(node_sum, left_sum[wid]);
auto left_child = (*p_tree)[nid].LeftChild();
Expand Down
6 changes: 4 additions & 2 deletions src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ void QuantileHistMaker::Builder::AddSplitsToTree(
spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
e.best.DefaultLeft(), e.weight, left_leaf_weight,
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess);
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess,
e.best.left_sum.GetHess(), e.best.right_sum.GetHess());

int left_id = (*p_tree)[nid].LeftChild();
int right_id = (*p_tree)[nid].RightChild();
Expand Down Expand Up @@ -410,7 +411,8 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
e.best.DefaultLeft(), e.weight, left_leaf_weight,
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess);
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess,
e.best.left_sum.GetHess(), e.best.right_sum.GetHess());

this->ApplySplit({candidate}, gmat, column_matrix, hist_, p_tree);

Expand Down
3 changes: 2 additions & 1 deletion src/tree/updater_skmaker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ class SketchMaker: public BaseMaker {
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
best.DefaultLeft(), base_weight, left_leaf_weight,
right_leaf_weight, best.loss_chg,
node_stats_[nid].sum_hess);
node_stats_[nid].sum_hess,
best.left_sum.GetHess(), best.right_sum.GetHess());
} else {
(*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate);
}
Expand Down
15 changes: 10 additions & 5 deletions tests/cpp/tree/test_prune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ TEST(Updater, Prune) {
pruner->Configure(cfg);

// loss_chg < min_split_loss;
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f);
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
pruner->Update(&gpair, p_dmat.get(), trees);

ASSERT_EQ(tree.NumExtraNodes(), 0);

// loss_chg > min_split_loss;
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f);
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
pruner->Update(&gpair, p_dmat.get(), trees);

ASSERT_EQ(tree.NumExtraNodes(), 2);
Expand All @@ -63,10 +65,12 @@ TEST(Updater, Prune) {
// loss_chg > min_split_loss
tree.ExpandNode(tree[0].LeftChild(),
0, 0.5f, true, 0.3, 0.4, 0.5,
/*loss_chg=*/18.0f, 0.0f);
/*loss_chg=*/18.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
tree.ExpandNode(tree[0].RightChild(),
0, 0.5f, true, 0.3, 0.4, 0.5,
/*loss_chg=*/19.0f, 0.0f);
/*loss_chg=*/19.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
cfg.emplace_back(std::make_pair("max_depth", "1"));
pruner->Configure(cfg);
pruner->Update(&gpair, p_dmat.get(), trees);
Expand All @@ -75,7 +79,8 @@ TEST(Updater, Prune) {

tree.ExpandNode(tree[0].LeftChild(),
0, 0.5f, true, 0.3, 0.4, 0.5,
/*loss_chg=*/18.0f, 0.0f);
/*loss_chg=*/18.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
cfg.emplace_back(std::make_pair("min_split_loss", "0"));
pruner->Configure(cfg);
pruner->Update(&gpair, p_dmat.get(), trees);
Expand Down
3 changes: 2 additions & 1 deletion tests/cpp/tree/test_refresh.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ TEST(Updater, Refresh) {
std::vector<RegTree*> trees {&tree};
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh", &lparam));

tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f);
tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
int cleft = tree[0].LeftChild();
int cright = tree[0].RightChild();

Expand Down
29 changes: 16 additions & 13 deletions tests/cpp/tree/test_tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@ TEST(Tree, Load) {

TEST(Tree, AllocateNode) {
RegTree tree;
tree.ExpandNode(
0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
tree.CollapseToLeaf(0, 0);
ASSERT_EQ(tree.NumExtraNodes(), 0);

tree.ExpandNode(
0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
ASSERT_EQ(tree.NumExtraNodes(), 2);

auto& nodes = tree.GetNodes();
Expand All @@ -107,18 +107,18 @@ RegTree ConstructTree() {
RegTree tree;
tree.ExpandNode(
/*nid=*/0, /*split_index=*/0, /*split_value=*/0.0f,
/*default_left=*/true,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
/*right_sum=*/0.0f);
auto left = tree[0].LeftChild();
auto right = tree[0].RightChild();
tree.ExpandNode(
/*nid=*/left, /*split_index=*/1, /*split_value=*/1.0f,
/*default_left=*/false,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
/*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
/*right_sum=*/0.0f);
tree.ExpandNode(
/*nid=*/right, /*split_index=*/2, /*split_value=*/2.0f,
/*default_left=*/false,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
/*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
/*right_sum=*/0.0f);
return tree;
}

Expand Down Expand Up @@ -222,7 +222,8 @@ TEST(Tree, DumpDot) {

TEST(Tree, JsonIO) {
RegTree tree;
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
Json j_tree{Object()};
tree.SaveModel(&j_tree);

Expand All @@ -246,8 +247,10 @@ TEST(Tree, JsonIO) {

auto left = tree[0].LeftChild();
auto right = tree[0].RightChild();
tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
tree.SaveModel(&j_tree);

tree.ChangeToLeaf(1, 1.0f);
Expand Down
Loading

0 comments on commit 83b8e68

Please sign in to comment.