Skip to content

Commit

Permalink
Fix rebase.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Apr 11, 2020
1 parent c3ad8b6 commit bc17174
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
5 changes: 4 additions & 1 deletion include/xgboost/tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,8 @@ class RegTree : public Model {
std::vector<float> const& left_leaf_weight,
std::vector<float> const& right_leaf_weight,
bst_float loss_change,
std::vector<double> const& sum_hess) {
std::vector<double> const& sum_hess,
std::vector<double> const& left_sum, std::vector<double> const& right_sum) {
int pleft = this->AllocNode();
int pright = this->AllocNode();
auto &node = nodes_[nid];
Expand All @@ -477,6 +478,8 @@ class RegTree : public Model {
this->multi_target_stats_.Set(nid, loss_change,
common::Span<float const>{base_weight},
common::Span<double const>{sum_hess});
this->multi_target_stats_.Set(pleft, 0, {left_leaf_weight}, {left_sum});
this->multi_target_stats_.Set(pright, 0, {right_leaf_weight}, {right_sum});
}

/*!
Expand Down
8 changes: 3 additions & 5 deletions src/tree/updater_exact.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,9 @@ size_t MultiExact<GradientT>::ExpandTree(RegTree *p_tree,
(left_weight * param_.learning_rate).vec,
(right_weight * param_.learning_rate).vec,
split.candidate.loss_chg,
split.parent_sum.GetHess().vec);
split.parent_sum.GetHess().vec,
split.candidate.left_sum.GetHess().vec,
split.candidate.right_sum.GetHess().vec);
auto left = tree[split.nidx].LeftChild();
auto right = tree[split.nidx].RightChild();
interaction_constraints_.Split(split.nidx, split.candidate.SplitIndex(), left, right);
Expand All @@ -304,8 +306,6 @@ size_t MultiExact<GradientT>::ExpandTree(RegTree *p_tree,
CHECK_EQ(is_splitable_[left], 1);
max_node = std::max(max_node, static_cast<size_t>(left));
} else {
tree.SetLeaf((left_weight * param_.learning_rate).vec, left,
split.candidate.left_sum.GetHess().vec);
is_splitable_[left] = 0;
}
if (SplitEntry::ChildIsValid(tree.GetDepth(right), leaves, param_)) {
Expand All @@ -317,8 +317,6 @@ size_t MultiExact<GradientT>::ExpandTree(RegTree *p_tree,
CHECK_EQ(is_splitable_[right], 1);
max_node = std::max(max_node, static_cast<size_t>(right));
} else {
tree.SetLeaf((right_weight * param_.learning_rate).vec, right,
split.candidate.right_sum.GetHess().vec);
is_splitable_[right] = 0;
}
}
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/tree/test_exact.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class MultiExactTest : public :: testing::Test {
void SetUp() override {
gradients_ = GenerateRandomGradients(kRows * kLabels, -1.0f, 1.0f);
auto h_grad = common::Span<GradientPair>{gradients_.HostVector()};
p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatix(true);
p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatrix(true);
p_dmat_->Info().labels_.Resize(kRows);

auto &h_labels = p_dmat_->Info().labels_.HostVector();
Expand Down

0 comments on commit bc17174

Please sign in to comment.