Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combine TreeModel and RegTree #3995

Merged
merged 3 commits into from
Dec 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
524 changes: 127 additions & 397 deletions include/xgboost/tree_model.h

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion src/gbm/gbtree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,6 @@ class GBTree : public GradientBooster {
// create new tree
std::unique_ptr<RegTree> ptr(new RegTree());
ptr->param.InitAllowUnknown(this->cfg_);
ptr->InitModel();
new_trees.push_back(ptr.get());
ret->push_back(std::move(ptr));
} else if (tparam_.process_type == kUpdate) {
Expand Down
236 changes: 236 additions & 0 deletions src/tree/tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,240 @@ std::string RegTree::DumpModel(const FeatureMap& fmap,
}
return fo.str();
}
// Build the cache of per-node mean values for every root of this tree.
// The cache is considered up to date whenever it already holds exactly
// param.num_nodes entries; in that case nothing is recomputed.
void RegTree::FillNodeMeanValues() {
  const size_t node_count = this->param.num_nodes;
  if (this->node_mean_values_.size() != node_count) {
    this->node_mean_values_.resize(node_count);
    // every root fills in its own subtree recursively
    int root = 0;
    while (root < param.num_roots) {
      this->FillNodeMeanValue(root);
      ++root;
    }
  }
}

// Recursively compute the mean value of the subtree rooted at nid, store it
// in node_mean_values_[nid] and return it.
//   - leaf:     the mean is the leaf value itself;
//   - internal: the children's means are weighted by each child's sum_hess,
//               summed, and divided by this node's own sum_hess.
bst_float RegTree::FillNodeMeanValue(int nid) {
  auto& current = (*this)[nid];
  bst_float mean;
  if (!current.IsLeaf()) {
    const auto left = current.LeftChild();
    const auto right = current.RightChild();
    // left subtree is always visited before the right one
    mean = this->FillNodeMeanValue(left) * this->Stat(left).sum_hess;
    mean += this->FillNodeMeanValue(right) * this->Stat(right).sum_hess;
    mean /= this->Stat(nid).sum_hess;
  } else {
    mean = current.LeafValue();
  }
  this->node_mean_values_[nid] = mean;
  return mean;
}

// Approximate per-feature contributions for one instance, following the idea
// of http://blog.datadive.net/interpreting-random-forests/
//
// The instance is routed from root_id down to a leaf.  The root's mean value
// is added to the bias slot out_contribs[feat.Size()], and at every split the
// change in mean value between parent and chosen child is added to the slot
// of the feature that was split on.  All updates are accumulated with +=.
// Requires node_mean_values_ to be non-empty (checked).
void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
                                           unsigned root_id,
                                           bst_float *out_contribs) const {
  CHECK_GT(this->node_mean_values_.size(), 0U);
  int nid = static_cast<int>(root_id);
  // bias term: expected value at the root
  bst_float parent_mean = this->node_mean_values_[nid];
  out_contribs[feat.Size()] += parent_mean;
  if ((*this)[nid].IsLeaf()) {
    // a single-leaf tree contributes only to the bias
    return;
  }
  unsigned last_split = 0;
  // the root is known to be an internal node here, so at least one step is taken
  do {
    last_split = (*this)[nid].SplitIndex();
    nid = this->GetNext(nid, feat.Fvalue(last_split), feat.IsMissing(last_split));
    const bst_float child_mean = this->node_mean_values_[nid];
    // credit the change in expectation to the feature just split on
    out_contribs[last_split] += child_mean - parent_mean;
    parent_mean = child_mean;
  } while (!(*this)[nid].IsLeaf());
  // difference between the leaf's value and its cached mean goes to the last split feature
  const bst_float leaf_value = (*this)[nid].LeafValue();
  out_contribs[last_split] += leaf_value - parent_mean;
}

// Used by TreeShap: one entry of the decision path that is being tracked.
// Note that pweight is kept here for convenience only and is not tied to the
// other attributes: the pweight of the i'th path element is the permutation
// weight of paths with i-1 ones in them.
struct PathElement {
  // feature associated with this path entry
  int feature_index;
  // fraction recorded for the "zero" extension of the path
  bst_float zero_fraction;
  // fraction recorded for the "one" extension of the path
  bst_float one_fraction;
  // permutation weight (see the note above)
  bst_float pweight;

  // defaulted on purpose: default-initialised entries are left unset
  PathElement() = default;

  PathElement(int feature, bst_float zero, bst_float one, bst_float weight)
      : feature_index{feature},
        zero_fraction{zero},
        one_fraction{one},
        pweight{weight} {}
};

// Extend the decision path with a fraction of one and zero extensions.
// Writes a new entry at unique_path[unique_depth] and then updates the
// permutation weights of entries 0..unique_depth in place, walking from the
// deepest existing entry back to the first.
void ExtendPath(PathElement *unique_path, unsigned unique_depth,
                bst_float zero_fraction, bst_float one_fraction,
                int feature_index) {
  PathElement &added = unique_path[unique_depth];
  added.feature_index = feature_index;
  added.zero_fraction = zero_fraction;
  added.one_fraction = one_fraction;
  // the very first entry starts with full weight, later ones start empty
  if (unique_depth == 0) {
    added.pweight = 1.0f;
  } else {
    added.pweight = 0.0f;
  }

  const bst_float denom = static_cast<bst_float>(unique_depth + 1);
  int k = static_cast<int>(unique_depth - 1);
  while (k >= 0) {
    unique_path[k + 1].pweight += one_fraction * unique_path[k].pweight * (k + 1)
                                  / denom;
    unique_path[k].pweight = zero_fraction * unique_path[k].pweight * (unique_depth - k)
                             / denom;
    --k;
  }
}

// Undo a previous extension of the decision path.
// First recomputes the permutation weights of entries 0..unique_depth-1 as if
// the entry at path_index had never been added, then removes that entry by
// shifting the feature_index / zero_fraction / one_fraction of the following
// entries down by one position (pweight is not shifted).
void UnwindPath(PathElement *unique_path, unsigned unique_depth,
                unsigned path_index) {
  const bst_float one_fraction = unique_path[path_index].one_fraction;
  const bst_float zero_fraction = unique_path[path_index].zero_fraction;

  // the fractions are fixed for the whole call, so pick the formula once
  if (one_fraction != 0) {
    bst_float carried = unique_path[unique_depth].pweight;
    for (int k = static_cast<int>(unique_depth - 1); k >= 0; --k) {
      PathElement &entry = unique_path[k];
      const bst_float previous = entry.pweight;
      entry.pweight = carried * (unique_depth + 1)
                      / static_cast<bst_float>((k + 1) * one_fraction);
      carried = previous - entry.pweight * zero_fraction * (unique_depth - k)
                / static_cast<bst_float>(unique_depth + 1);
    }
  } else {
    for (int k = static_cast<int>(unique_depth - 1); k >= 0; --k) {
      PathElement &entry = unique_path[k];
      entry.pweight = (entry.pweight * (unique_depth + 1))
                      / static_cast<bst_float>(zero_fraction * (unique_depth - k));
    }
  }

  // close the gap left by the removed entry
  for (unsigned k = path_index; k < unique_depth; ++k) {
    const PathElement &following = unique_path[k + 1];
    unique_path[k].feature_index = following.feature_index;
    unique_path[k].zero_fraction = following.zero_fraction;
    unique_path[k].one_fraction = following.one_fraction;
  }
}

// Determine what the total permutation weight would be if we unwound a
// previous extension in the decision path.
//
// unique_path   decision path with entries 0..unique_depth (read only)
// unique_depth  index of the last entry on the path
// path_index    entry whose extension is hypothetically undone
//
// Returns the sum of the permutation weights the path would have after the
// unwind; unique_path itself is not modified.
//
// When the entry has one_fraction == 0 and zero_fraction == 0 there is nothing
// to divide by; such an entry adds nothing to the total instead of producing
// inf/NaN that would then be accumulated into the SHAP values.
bst_float UnwoundPathSum(const PathElement *unique_path, unsigned unique_depth,
                         unsigned path_index) {
  const bst_float one_fraction = unique_path[path_index].one_fraction;
  const bst_float zero_fraction = unique_path[path_index].zero_fraction;
  bst_float next_one_portion = unique_path[unique_depth].pweight;
  bst_float total = 0;
  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      const bst_float tmp = next_one_portion * (unique_depth + 1)
                            / static_cast<bst_float>((i + 1) * one_fraction);
      total += tmp;
      next_one_portion = unique_path[i].pweight - tmp * zero_fraction * ((unique_depth - i)
                         / static_cast<bst_float>(unique_depth + 1));
    } else if (zero_fraction != 0) {
      total += (unique_path[i].pweight / zero_fraction) / ((unique_depth - i)
               / static_cast<bst_float>(unique_depth + 1));
    }
    // else: both fractions are zero -> no weight flows through this entry
  }
  return total;
}

// Recursive computation of SHAP values for a decision tree.
//
// Processes the subtree rooted at node_index.  The set of features split on
// so far is tracked in a path of PathElement entries; at each leaf the
// contribution of every feature on the path is accumulated into phi.
//
// feat                  feature vector of the instance being explained
// phi                   output array indexed by feature index; updated with +=
// node_index            node handled by this call
// unique_depth          index at which this call appends its entry to the path
// parent_unique_path    start of the caller's path segment; this call builds its
//                       own segment directly behind it in the same buffer
// parent_zero_fraction, parent_one_fraction, parent_feature_index
//                       values handed to ExtendPath for the split leading here
// condition             0 = no conditioning; > 0 and < 0 select the two
//                       conditioning modes handled further below
// condition_feature     feature that is conditioned on when condition != 0
// condition_fraction    multiplier applied to every leaf contribution below
//                       this node; the call is a no-op when it is 0
void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
unsigned node_index, unsigned unique_depth,
PathElement *parent_unique_path,
bst_float parent_zero_fraction,
bst_float parent_one_fraction, int parent_feature_index,
int condition, unsigned condition_feature,
bst_float condition_fraction) const {
// taken by value (const auto, not a reference)
const auto node = (*this)[node_index];

// stop if we have no weight coming down to us
if (condition_fraction == 0) return;

// extend the unique path
// this call's segment starts right after the caller's unique_depth + 1 entries
// and begins as a copy of them
PathElement *unique_path = parent_unique_path + unique_depth + 1;
std::copy(parent_unique_path, parent_unique_path + unique_depth + 1, unique_path);

// the path is not extended when conditioning on the very feature that led here
if (condition == 0 || condition_feature != static_cast<unsigned>(parent_feature_index)) {
ExtendPath(unique_path, unique_depth, parent_zero_fraction,
parent_one_fraction, parent_feature_index);
}
// read before the leaf test, but only used in the internal-node branch
const unsigned split_index = node.SplitIndex();

// leaf node
if (node.IsLeaf()) {
// entry 0 is skipped; when entered via CalculateContributions it is the
// root placeholder created with feature_index -1
for (unsigned i = 1; i <= unique_depth; ++i) {
const bst_float w = UnwoundPathSum(unique_path, unique_depth, i);
const PathElement &el = unique_path[i];
phi[el.feature_index] += w * (el.one_fraction - el.zero_fraction)
* node.LeafValue() * condition_fraction;
}

// internal node
} else {
// find which branch is "hot" (meaning x would follow it)
unsigned hot_index = 0;
if (feat.IsMissing(split_index)) {
hot_index = node.DefaultChild();
} else if (feat.Fvalue(split_index) < node.SplitCond()) {
hot_index = node.LeftChild();
} else {
hot_index = node.RightChild();
}
// the "cold" branch is simply the other child
const unsigned cold_index = (static_cast<int>(hot_index) == node.LeftChild() ?
node.RightChild() : node.LeftChild());
// each child's zero fraction is its share of this node's sum_hess
// NOTE(review): no guard here for w == 0
const bst_float w = this->Stat(node_index).sum_hess;
const bst_float hot_zero_fraction = this->Stat(hot_index).sum_hess / w;
const bst_float cold_zero_fraction = this->Stat(cold_index).sum_hess / w;
bst_float incoming_zero_fraction = 1;
bst_float incoming_one_fraction = 1;

// see if we have already split on this feature,
// if so we undo that split so we can redo it for this node
unsigned path_index = 0;
for (; path_index <= unique_depth; ++path_index) {
if (static_cast<unsigned>(unique_path[path_index].feature_index) == split_index) break;
}
// path_index == unique_depth + 1 means the scan finished without a match
if (path_index != unique_depth + 1) {
// keep the fractions of the earlier split so they carry into the children
incoming_zero_fraction = unique_path[path_index].zero_fraction;
incoming_one_fraction = unique_path[path_index].one_fraction;
UnwindPath(unique_path, unique_depth, path_index);
unique_depth -= 1;
}

// divide up the condition_fraction among the recursive calls
bst_float hot_condition_fraction = condition_fraction;
bst_float cold_condition_fraction = condition_fraction;
// The decrements below cancel the "+ 1" of the recursive calls: the children
// skip ExtendPath for the conditioned feature, so the path must not grow.
// unique_depth is unsigned; a decrement from 0 wraps around and the "+ 1"
// in the calls below wraps it back to 0.
if (condition > 0 && split_index == condition_feature) {
// a zero fraction makes the cold call return immediately
cold_condition_fraction = 0;
unique_depth -= 1;
} else if (condition < 0 && split_index == condition_feature) {
// both branches are followed, each scaled by its zero fraction
hot_condition_fraction *= hot_zero_fraction;
cold_condition_fraction *= cold_zero_fraction;
unique_depth -= 1;
}

// the hot child keeps the incoming one fraction
TreeShap(feat, phi, hot_index, unique_depth + 1, unique_path,
hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction,
split_index, condition, condition_feature, hot_condition_fraction);

// the cold child always gets a one fraction of 0
TreeShap(feat, phi, cold_index, unique_depth + 1, unique_path,
cold_zero_fraction * incoming_zero_fraction, 0,
split_index, condition, condition_feature, cold_condition_fraction);
}
}

// Compute per-feature contributions for one instance with the recursive
// TreeShap algorithm and accumulate them into out_contribs.
//
// feat               feature vector of the instance being explained
// root_id            root node the traversal starts from
// out_contribs       output array; feature slots are updated with += and, when
//                    condition == 0, the slot at index feat.Size() receives the
//                    root's cached mean value (the bias term)
// condition          forwarded to TreeShap; 0 means no conditioning
// condition_feature  forwarded to TreeShap; feature to condition on
//
// When condition == 0 the node mean values must have been filled beforehand
// (see FillNodeMeanValues); this is checked.
void RegTree::CalculateContributions(const RegTree::FVec &feat,
                                     unsigned root_id, bst_float *out_contribs,
                                     int condition,
                                     unsigned condition_feature) const {
  // find the expected value of the tree's predictions
  if (condition == 0) {
    // same precondition as CalculateContributionsApprox: fail loudly instead
    // of indexing an empty cache
    CHECK_GT(this->node_mean_values_.size(), 0U);
    bst_float node_value = this->node_mean_values_[static_cast<int>(root_id)];
    out_contribs[feat.Size()] += node_value;
  }

  // Preallocate space for the unique path data.  Each recursion level of
  // TreeShap places its path segment behind its parent's, so one buffer sized
  // from the tree depth serves the whole traversal.  The vector owns the
  // storage (released on every exit path) and value-initialises its entries.
  const int maxd = this->MaxDepth(root_id) + 2;
  std::vector<PathElement> unique_path_data(
      static_cast<size_t>((maxd * (maxd + 1)) / 2));

  TreeShap(feat, out_contribs, root_id, 0, unique_path_data.data(),
           1, 1, -1, condition, condition_feature, 1);
}
} // namespace xgboost
2 changes: 1 addition & 1 deletion src/tree/updater_prune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class TreePruner: public TreeUpdater {
inline int TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) { // NOLINT(*)
if (tree[nid].IsRoot()) return npruned;
int pid = tree[nid].Parent();
RegTree::NodeStat &s = tree.Stat(pid);
RTreeNodeStat &s = tree.Stat(pid);
++s.leaf_child_cnt;
if (s.leaf_child_cnt >= 2 && param_.NeedPrune(s.loss_chg, depth - 1)) {
// need to be pruned
Expand Down
1 change: 0 additions & 1 deletion tests/cpp/predictor/test_cpu_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ TEST(cpu_predictor, Test) {

std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
trees.back()->InitModel();
(*trees.back())[0].SetLeaf(1.5f);
(*trees.back()).Stat(0).sum_hess = 1.0f;
gbm::GBTreeModel model(0.5);
Expand Down
2 changes: 0 additions & 2 deletions tests/cpp/predictor/test_gpu_predictor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ TEST(gpu_predictor, Test) {

std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree()));
trees.back()->InitModel();
(*trees.back())[0].SetLeaf(1.5f);
(*trees.back()).Stat(0).sum_hess = 1.0f;
gbm::GBTreeModel model(0.5);
Expand Down Expand Up @@ -181,7 +180,6 @@ TEST(gpu_predictor, MGPU_Test) {

std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree()));
trees.back()->InitModel();
(*trees.back())[0].SetLeaf(1.5f);
(*trees.back()).Stat(0).sum_hess = 1.0f;
gbm::GBTreeModel model(0.5);
Expand Down
3 changes: 0 additions & 3 deletions tests/cpp/tree/test_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,6 @@ TEST(GpuHist, EvaluateSplits) {
false);

RegTree tree;
tree.InitModel();

MetaInfo info;
info.num_row_ = n_rows;
info.num_col_ = n_cols;
Expand Down Expand Up @@ -338,7 +336,6 @@ TEST(GpuHist, ApplySplit) {
// Initialize GPUHistMaker
hist_maker.param_ = param;
RegTree tree;
tree.InitModel();

DeviceSplitCandidate candidate;
candidate.Update(2, kLeftDir,
Expand Down
1 change: 0 additions & 1 deletion tests/cpp/tree/test_prune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ TEST(Updater, Prune) {

// prepare tree
RegTree tree = RegTree();
tree.InitModel();
tree.param.InitAllowUnknown(cfg);
std::vector<RegTree*> trees {&tree};
// prepare pruner
Expand Down
3 changes: 0 additions & 3 deletions tests/cpp/tree/test_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ class QuantileHistMock : public QuantileHistMaker {
gmat.Init((*dmat).get(), max_bins);

RegTree tree = RegTree();
tree.InitModel();
tree.param.InitAllowUnknown(cfg);

std::vector<GradientPair> gpair =
Expand All @@ -134,7 +133,6 @@ class QuantileHistMock : public QuantileHistMaker {

void TestBuildHist() {
RegTree tree = RegTree();
tree.InitModel();
tree.param.InitAllowUnknown(cfg);

size_t constexpr max_bins = 4;
Expand All @@ -146,7 +144,6 @@ class QuantileHistMock : public QuantileHistMaker {

void TestEvaluateSplit() {
RegTree tree = RegTree();
tree.InitModel();
tree.param.InitAllowUnknown(cfg);

builder_->TestEvaluateSplit(gmatb_, tree);
Expand Down
1 change: 0 additions & 1 deletion tests/cpp/tree/test_refresh.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ TEST(Updater, Refresh) {
{"reg_lambda", "1"}};

RegTree tree = RegTree();
tree.InitModel();
tree.param.InitAllowUnknown(cfg);
std::vector<RegTree*> trees {&tree};
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh"));
Expand Down