From e0509b330734c1d6ee32b843c273a6934df24319 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 25 Feb 2020 08:32:46 +0800 Subject: [PATCH] Fix pruner. (#5335) * Honor the tree depth. * Prevent pruning pruned node. --- include/xgboost/tree_model.h | 30 ++++++++++++++++++--------- src/tree/param.h | 5 +++-- src/tree/updater_prune.cc | 36 +++++++++++++++++++-------------- tests/cpp/tree/test_prune.cc | 38 +++++++++++++++++++++++++++-------- tests/python/test_updaters.py | 24 ++++++++++++++++++++++ 5 files changed, 99 insertions(+), 34 deletions(-) diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 5c0b1caadfea..02ac19f9ae1f 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -99,9 +99,10 @@ struct RTreeNodeStat { */ class RegTree : public Model { public: - /*! \brief auxiliary statistics of node to help tree building */ using SplitCondT = bst_float; static constexpr int32_t kInvalidNodeId {-1}; + static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits::max(); + /*! \brief tree node */ class Node { public: @@ -158,7 +159,7 @@ class RegTree : public Model { } /*! \brief whether this node is deleted */ XGBOOST_DEVICE bool IsDeleted() const { - return sindex_ == std::numeric_limits::max(); + return sindex_ == kDeletedNodeMarker; } /*! \brief whether current node is root */ XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; } @@ -201,7 +202,7 @@ class RegTree : public Model { } /*! \brief mark that this node is deleted */ XGBOOST_DEVICE void MarkDelete() { - this->sindex_ = std::numeric_limits::max(); + this->sindex_ = kDeletedNodeMarker; } /*! \brief Reuse this deleted node. */ XGBOOST_DEVICE void Reuse() { @@ -534,6 +535,13 @@ class RegTree : public Model { // delete a tree node, keep the parent field to allow trace back void DeleteNode(int nid) { CHECK_GE(nid, 1); + auto pid = (*this)[nid].Parent(); + if (nid == (*this)[pid].LeftChild()) { + (*this)[pid].SetLeftChild(kInvalidNodeId); + } else { + (*this)[pid].SetRightChild(kInvalidNodeId); + } + deleted_nodes_.push_back(nid); nodes_[nid].MarkDelete(); ++param.num_deleted; @@ -548,16 +556,20 @@ inline void RegTree::FVec::Init(size_t size) { } inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) { - for (bst_uint i = 0; i < inst.size(); ++i) { - if (inst[i].index >= data_.size()) continue; - data_[inst[i].index].fvalue = inst[i].fvalue; + for (auto const& entry : inst) { + if (entry.index >= data_.size()) { + continue; + } + data_[entry.index].fvalue = entry.fvalue; } } inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) { - for (bst_uint i = 0; i < inst.size(); ++i) { - if (inst[i].index >= data_.size()) continue; - data_[inst[i].index].flag = -1; + for (auto const& entry : inst) { + if (entry.index >= data_.size()) { + continue; + } + data_[entry.index].flag = -1; } } diff --git a/src/tree/param.h b/src/tree/param.h index da63895a5ef8..213aeb14fe01 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -220,8 +220,9 @@ struct TrainParam : public XGBoostParameter { } /*! \brief given the loss change, whether we need to invoke pruning */ - inline bool NeedPrune(double loss_chg, int depth) const { - return loss_chg < this->min_split_loss; + bool NeedPrune(double loss_chg, int depth) const { + return loss_chg < this->min_split_loss || + (this->max_depth != 0 && depth > this->max_depth); } /*! \brief maximum sketch size */ inline unsigned MaxSketchSize() const { diff --git a/src/tree/updater_prune.cc b/src/tree/updater_prune.cc index e9437de93d69..a621fca46ec5 100644 --- a/src/tree/updater_prune.cc +++ b/src/tree/updater_prune.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2014 by Contributors + * Copyright 2014-2020 by Contributors * \file updater_prune.cc * \brief prune a tree given the statistics * \author Tianqi Chen @@ -10,6 +10,7 @@ #include #include +#include "xgboost/base.h" #include "xgboost/json.h" #include "./param.h" #include "../common/io.h" @@ -52,7 +53,7 @@ class TreePruner: public TreeUpdater { float lr = param_.learning_rate; param_.learning_rate = lr / trees.size(); for (auto tree : trees) { - this->DoPrune(*tree); + this->DoPrune(tree); } param_.learning_rate = lr; syncher_->Update(gpair, p_fmat, trees); @@ -60,12 +61,20 @@ class TreePruner: public TreeUpdater { private: // try to prune off current leaf - inline int TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) { // NOLINT(*) - if (tree[nid].IsRoot()) return npruned; - int pid = tree[nid].Parent(); - RTreeNodeStat &s = tree.Stat(pid); - ++s.leaf_child_cnt; - if (s.leaf_child_cnt >= 2 && param_.NeedPrune(s.loss_chg, depth - 1)) { + bst_node_t TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) { // NOLINT(*) + CHECK(tree[nid].IsLeaf()); + if (tree[nid].IsRoot()) { + return npruned; + } + bst_node_t pid = tree[nid].Parent(); + CHECK(!tree[pid].IsLeaf()); + RTreeNodeStat const &s = tree.Stat(pid); + // Only prune when both child are leaf. + auto left = tree[pid].LeftChild(); + auto right = tree[pid].RightChild(); + bool balanced = tree[left].IsLeaf() && + right != RegTree::kInvalidNodeId && tree[right].IsLeaf(); + if (balanced && param_.NeedPrune(s.loss_chg, depth)) { // need to be pruned tree.ChangeToLeaf(pid, param_.learning_rate * s.base_weight); // tail recursion @@ -75,14 +84,11 @@ class TreePruner: public TreeUpdater { } } /*! \brief do pruning of a tree */ - inline void DoPrune(RegTree &tree) { // NOLINT(*) - int npruned = 0; - // initialize auxiliary statistics - for (int nid = 0; nid < tree.param.num_nodes; ++nid) { - tree.Stat(nid).leaf_child_cnt = 0; - } + void DoPrune(RegTree* p_tree) { + auto& tree = *p_tree; + bst_node_t npruned = 0; for (int nid = 0; nid < tree.param.num_nodes; ++nid) { - if (tree[nid].IsLeaf()) { + if (tree[nid].IsLeaf() && !tree[nid].IsDeleted()) { npruned = this->TryPruneLeaf(tree, nid, tree.GetDepth(nid), npruned); } } diff --git a/tests/cpp/tree/test_prune.cc b/tests/cpp/tree/test_prune.cc index f29e42ffd99e..dd703c993471 100644 --- a/tests/cpp/tree/test_prune.cc +++ b/tests/cpp/tree/test_prune.cc @@ -1,33 +1,34 @@ /*! * Copyright 2018-2019 by Contributors */ -#include "../helpers.h" +#include #include #include +#include #include #include #include #include +#include "../helpers.h" + namespace xgboost { namespace tree { TEST(Updater, Prune) { - int constexpr kNCols = 16; + int constexpr kCols = 16; std::vector> cfg; - cfg.emplace_back(std::pair( - "num_feature", std::to_string(kNCols))); + cfg.emplace_back(std::pair("num_feature", + std::to_string(kCols))); cfg.emplace_back(std::pair( "min_split_loss", "10")); - cfg.emplace_back(std::pair( - "silent", "1")); // These data are just place holders. HostDeviceVector gpair = { {0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f}, {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f} }; - auto dmat = CreateDMatrix(32, 16, 0.4, 3); + auto dmat = CreateDMatrix(32, kCols, 0.4, 3); auto lparam = CreateEmptyGenericParam(GPUIDX); @@ -57,8 +58,29 @@ TEST(Updater, Prune) { ASSERT_EQ(tree.NumExtraNodes(), 2); + // Test depth + // loss_chg > min_split_loss + tree.ExpandNode(tree[0].LeftChild(), + 0, 0.5f, true, 0.3, 0.4, 0.5, + /*loss_chg=*/18.0f, 0.0f); + tree.ExpandNode(tree[0].RightChild(), + 0, 0.5f, true, 0.3, 0.4, 0.5, + /*loss_chg=*/19.0f, 0.0f); + cfg.emplace_back(std::make_pair("max_depth", "1")); + pruner->Configure(cfg); + pruner->Update(&gpair, dmat->get(), trees); + + ASSERT_EQ(tree.NumExtraNodes(), 2); + + tree.ExpandNode(tree[0].LeftChild(), + 0, 0.5f, true, 0.3, 0.4, 0.5, + /*loss_chg=*/18.0f, 0.0f); + cfg.emplace_back(std::make_pair("min_split_loss", "0")); + pruner->Configure(cfg); + pruner->Update(&gpair, dmat->get(), trees); + ASSERT_EQ(tree.NumExtraNodes(), 2); + delete dmat; } - } // namespace tree } // namespace xgboost diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py index 593133b75190..6dc9c77b6306 100644 --- a/tests/python/test_updaters.py +++ b/tests/python/test_updaters.py @@ -26,6 +26,30 @@ def test_colmaker(self): result = run_suite(param) assert_results_non_increasing(result, 1e-2) + @pytest.mark.skipif(**tm.no_sklearn()) + def test_pruner(self): + import sklearn + params = {'tree_method': 'exact'} + cancer = sklearn.datasets.load_breast_cancer() + X = cancer['data'] + y = cancer["target"] + + dtrain = xgb.DMatrix(X, y) + booster = xgb.train(params, dtrain=dtrain, num_boost_round=10) + grown = str(booster.get_dump()) + + params = {'updater': 'prune', 'process_type': 'update', 'gamma': '0.2'} + booster = xgb.train(params, dtrain=dtrain, num_boost_round=10, + xgb_model=booster) + after_prune = str(booster.get_dump()) + assert grown != after_prune + + booster = xgb.train(params, dtrain=dtrain, num_boost_round=10, + xgb_model=booster) + second_prune = str(booster.get_dump()) + # Second prune should not change the tree + assert after_prune == second_prune + @pytest.mark.skipif(**tm.no_sklearn()) def test_fast_histmaker(self): variable_param = {'tree_method': ['hist'],