Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept other gradient types for split entry. #5467

Merged
merged 1 commit into from
Apr 3, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 41 additions & 35 deletions src/tree/param.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,19 +286,6 @@ XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, StatT stat) {
return CalcGain(p, stat.GetGrad(), stat.GetHess());
}

// calculate cost of loss function with four statistics
template <typename TrainingParams, typename T>
XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess,
T test_grad, T test_hess) {
T w = CalcWeight(sum_grad, sum_hess);
T ret = CalcGainGivenWeight(p, test_grad, test_hess);
if (p.reg_alpha == 0.0f) {
return ret;
} else {
return ret + p.reg_alpha * std::abs(w);
}
}

// calculate weight given the statistics
template <typename TrainingParams, typename T>
XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
Expand Down Expand Up @@ -340,6 +327,11 @@ struct XGBOOST_ALIGNAS(16) GradStats {
XGBOOST_DEVICE double GetGrad() const { return sum_grad; }
XGBOOST_DEVICE double GetHess() const { return sum_hess; }

friend std::ostream& operator<<(std::ostream& os, GradStats s) {
os << s.GetGrad() << "/" << s.GetHess();
return os;
}

XGBOOST_DEVICE GradStats() : sum_grad{0}, sum_hess{0} {
static_assert(sizeof(GradStats) == 16,
"Size of GradStats is not 16 bytes.");
Expand Down Expand Up @@ -383,28 +375,42 @@ struct XGBOOST_ALIGNAS(16) GradStats {
* \brief statistics that is helpful to store
* and represent a split solution for the tree
*/
struct SplitEntry {
template<typename GradientT>
struct SplitEntryContainer {
/*! \brief loss change after split this node */
bst_float loss_chg {0.0f};
/*! \brief split index */
unsigned sindex{0};
bst_feature_t sindex{0};
bst_float split_value{0.0f};
GradStats left_sum;
GradStats right_sum;

/*! \brief constructor */
SplitEntry() = default;
GradientT left_sum;
GradientT right_sum;

SplitEntryContainer() = default;

friend std::ostream& operator<<(std::ostream& os, SplitEntryContainer const& s) {
os << "loss_chg: " << s.loss_chg << ", "
<< "split index: " << s.SplitIndex() << ", "
<< "split value: " << s.split_value << ", "
<< "left_sum: " << s.left_sum << ", "
<< "right_sum: " << s.right_sum;
return os;
}
/*!\return feature index to split on */
bst_feature_t SplitIndex() const { return sindex & ((1U << 31) - 1U); }
/*!\return whether missing value goes to left branch */
bool DefaultLeft() const { return (sindex >> 31) != 0; }
/*!
* \brief decides whether we can replace current entry with the given
* statistics
* This function gives better priority to lower index when loss_chg ==
* new_loss_chg.
* \brief decides whether we can replace current entry with the given statistics
*
* This function gives better priority to lower index when loss_chg == new_loss_chg.
* Not the best way, but helps to give consistent result during multi-thread
* execution.
* execution.
*
* \param new_loss_chg the loss reduction get through the split
* \param split_index the feature index where the split is on
*/
inline bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
if (this->SplitIndex() <= split_index) {
return new_loss_chg > this->loss_chg;
} else {
Expand All @@ -416,7 +422,7 @@ struct SplitEntry {
* \param e candidate split solution
* \return whether the proposed split is better and can replace current split
*/
inline bool Update(const SplitEntry &e) {
inline bool Update(const SplitEntryContainer &e) {
if (this->NeedReplace(e.loss_chg, e.SplitIndex())) {
this->loss_chg = e.loss_chg;
this->sindex = e.sindex;
Expand All @@ -436,9 +442,10 @@ struct SplitEntry {
* \param default_left whether the missing value goes to left
* \return whether the proposed split is better and can replace current split
*/
inline bool Update(bst_float new_loss_chg, unsigned split_index,
bst_float new_split_value, bool default_left,
const GradStats &left_sum, const GradStats &right_sum) {
bool Update(bst_float new_loss_chg, unsigned split_index,
bst_float new_split_value, bool default_left,
const GradientT &left_sum,
const GradientT &right_sum) {
if (this->NeedReplace(new_loss_chg, split_index)) {
this->loss_chg = new_loss_chg;
if (default_left) {
Expand All @@ -453,17 +460,16 @@ struct SplitEntry {
return false;
}
}

/*! \brief same as update, used by AllReduce*/
inline static void Reduce(SplitEntry &dst, // NOLINT(*)
const SplitEntry &src) { // NOLINT(*)
inline static void Reduce(SplitEntryContainer &dst, // NOLINT(*)
const SplitEntryContainer &src) { // NOLINT(*)
dst.Update(src);
}
/*!\return feature index to split on */
inline unsigned SplitIndex() const { return sindex & ((1U << 31) - 1U); }
/*!\return whether missing value goes to left branch */
inline bool DefaultLeft() const { return (sindex >> 31) != 0; }
};

using SplitEntry = SplitEntryContainer<GradStats>;

} // namespace tree
} // namespace xgboost

Expand Down