From 8ab4cbfbeb354664e288069eae8a7a13ae8ba78c Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 1 Apr 2020 18:24:07 +0800 Subject: [PATCH] Accept other gradient types for split entry. --- src/tree/param.h | 76 ++++++++++++++++++++++++++---------------------- 1 file changed, 41 insertions(+), 35 deletions(-) diff --git a/src/tree/param.h b/src/tree/param.h index 213aeb14fe01..97ae94cda83b 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -286,19 +286,6 @@ XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, StatT stat) { return CalcGain(p, stat.GetGrad(), stat.GetHess()); } -// calculate cost of loss function with four statistics -template -XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess, - T test_grad, T test_hess) { - T w = CalcWeight(sum_grad, sum_hess); - T ret = CalcGainGivenWeight(p, test_grad, test_hess); - if (p.reg_alpha == 0.0f) { - return ret; - } else { - return ret + p.reg_alpha * std::abs(w); - } -} - // calculate weight given the statistics template XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad, @@ -340,6 +327,11 @@ struct XGBOOST_ALIGNAS(16) GradStats { XGBOOST_DEVICE double GetGrad() const { return sum_grad; } XGBOOST_DEVICE double GetHess() const { return sum_hess; } + friend std::ostream& operator<<(std::ostream& os, GradStats s) { + os << s.GetGrad() << "/" << s.GetHess(); + return os; + } + XGBOOST_DEVICE GradStats() : sum_grad{0}, sum_hess{0} { static_assert(sizeof(GradStats) == 16, "Size of GradStats is not 16 bytes."); @@ -383,28 +375,42 @@ struct XGBOOST_ALIGNAS(16) GradStats { * \brief statistics that is helpful to store * and represent a split solution for the tree */ -struct SplitEntry { +template +struct SplitEntryContainer { /*! \brief loss change after split this node */ bst_float loss_chg {0.0f}; /*! \brief split index */ - unsigned sindex{0}; + bst_feature_t sindex{0}; bst_float split_value{0.0f}; - GradStats left_sum; - GradStats right_sum; - /*! \brief constructor */ - SplitEntry() = default; + GradientT left_sum; + GradientT right_sum; + + SplitEntryContainer() = default; + + friend std::ostream& operator<<(std::ostream& os, SplitEntryContainer const& s) { + os << "loss_chg: " << s.loss_chg << ", " + << "split index: " << s.SplitIndex() << ", " + << "split value: " << s.split_value << ", " + << "left_sum: " << s.left_sum << ", " + << "right_sum: " << s.right_sum; + return os; + } + /*!\return feature index to split on */ + bst_feature_t SplitIndex() const { return sindex & ((1U << 31) - 1U); } + /*!\return whether missing value goes to left branch */ + bool DefaultLeft() const { return (sindex >> 31) != 0; } /*! - * \brief decides whether we can replace current entry with the given - * statistics - * This function gives better priority to lower index when loss_chg == - * new_loss_chg. + * \brief decides whether we can replace current entry with the given statistics + * + * This function gives better priority to lower index when loss_chg == new_loss_chg. * Not the best way, but helps to give consistent result during multi-thread - * execution. + * execution. + * * \param new_loss_chg the loss reduction get through the split * \param split_index the feature index where the split is on */ - inline bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const { + bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const { if (this->SplitIndex() <= split_index) { return new_loss_chg > this->loss_chg; } else { @@ -416,7 +422,7 @@ struct SplitEntry { * \param e candidate split solution * \return whether the proposed split is better and can replace current split */ - inline bool Update(const SplitEntry &e) { + inline bool Update(const SplitEntryContainer &e) { if (this->NeedReplace(e.loss_chg, e.SplitIndex())) { this->loss_chg = e.loss_chg; this->sindex = e.sindex; @@ -436,9 +442,10 @@ struct SplitEntry { * \param default_left whether the missing value goes to left * \return whether the proposed split is better and can replace current split */ - inline bool Update(bst_float new_loss_chg, unsigned split_index, - bst_float new_split_value, bool default_left, - const GradStats &left_sum, const GradStats &right_sum) { + bool Update(bst_float new_loss_chg, unsigned split_index, + bst_float new_split_value, bool default_left, + const GradientT &left_sum, + const GradientT &right_sum) { if (this->NeedReplace(new_loss_chg, split_index)) { this->loss_chg = new_loss_chg; if (default_left) { @@ -453,17 +460,16 @@ struct SplitEntry { return false; } } + /*! \brief same as update, used by AllReduce*/ - inline static void Reduce(SplitEntry &dst, // NOLINT(*) - const SplitEntry &src) { // NOLINT(*) + inline static void Reduce(SplitEntryContainer &dst, // NOLINT(*) + const SplitEntryContainer &src) { // NOLINT(*) dst.Update(src); } - /*!\return feature index to split on */ - inline unsigned SplitIndex() const { return sindex & ((1U << 31) - 1U); } - /*!\return whether missing value goes to left branch */ - inline bool DefaultLeft() const { return (sindex >> 31) != 0; } }; +using SplitEntry = SplitEntryContainer; + } // namespace tree } // namespace xgboost