From 84c99f86f46b1ff901b9416233ae5eb172c18755 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Wed, 19 Dec 2018 12:16:40 +1300 Subject: [PATCH] Combine TreeModel and RegTree (#3995) --- include/xgboost/tree_model.h | 524 ++++++---------------- src/gbm/gbtree.cc | 1 - src/tree/tree_model.cc | 236 ++++++++++ src/tree/updater_prune.cc | 2 +- tests/cpp/predictor/test_cpu_predictor.cc | 1 - tests/cpp/predictor/test_gpu_predictor.cu | 2 - tests/cpp/tree/test_gpu_hist.cu | 3 - tests/cpp/tree/test_prune.cc | 1 - tests/cpp/tree/test_quantile_hist.cc | 3 - tests/cpp/tree/test_refresh.cc | 1 - 10 files changed, 364 insertions(+), 410 deletions(-) diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 554c7cd92dc8..56e2820e8c9d 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -22,6 +22,8 @@ namespace xgboost { +struct PathElement; // forward declaration + /*! \brief meta parameters of the tree */ struct TreeParam : public dmlc::Parameter { /*! \brief number of start root */ @@ -62,18 +64,26 @@ struct TreeParam : public dmlc::Parameter { } }; +/*! \brief node statistics used in regression tree */ +struct RTreeNodeStat { + /*! \brief loss change caused by current split */ + bst_float loss_chg; + /*! \brief sum of hessian values, used to measure coverage of data */ + bst_float sum_hess; + /*! \brief weight of current node */ + bst_float base_weight; + /*! \brief number of child that is leaf node known up to now */ + int leaf_child_cnt; +}; + /*! - * \brief template class of TreeModel - * \tparam TSplitCond data type to indicate split condition - * \tparam TNodeStat auxiliary statistics of node to help tree building + * \brief define regression tree to be the most common tree model. + * This is the data structure used in xgboost's major tree models. */ -template -class TreeModel { +class RegTree { public: - /*! \brief data type to indicate split condition */ - using NodeStat = TNodeStat; /*! 
\brief auxiliary statistics of node to help tree building */ - using SplitCond = TSplitCond; + using SplitCondT = bst_float; /*! \brief tree node */ class Node { public: @@ -83,58 +93,65 @@ class TreeModel { "Node: 64 bit align"); } /*! \brief index of left child */ - inline int LeftChild() const { + int LeftChild() const { return this->cleft_; } /*! \brief index of right child */ - inline int RightChild() const { + int RightChild() const { return this->cright_; } /*! \brief index of default child when feature is missing */ - inline int DefaultChild() const { + int DefaultChild() const { return this->DefaultLeft() ? this->LeftChild() : this->RightChild(); } /*! \brief feature index of split condition */ - inline unsigned SplitIndex() const { + unsigned SplitIndex() const { return sindex_ & ((1U << 31) - 1U); } /*! \brief when feature is unknown, whether goes to left child */ - inline bool DefaultLeft() const { + bool DefaultLeft() const { return (sindex_ >> 31) != 0; } /*! \brief whether current node is leaf node */ - inline bool IsLeaf() const { + bool IsLeaf() const { return cleft_ == -1; } /*! \return get leaf value of leaf node */ - inline bst_float LeafValue() const { + bst_float LeafValue() const { return (this->info_).leaf_value; } /*! \return get split condition of the node */ - inline TSplitCond SplitCond() const { + SplitCondT SplitCond() const { return (this->info_).split_cond; } /*! \brief get parent of the node */ - inline int Parent() const { + int Parent() const { return parent_ & ((1U << 31) - 1); } /*! \brief whether current node is left child */ - inline bool IsLeftChild() const { + bool IsLeftChild() const { return (parent_ & (1U << 31)) != 0; } /*! \brief whether this node is deleted */ - inline bool IsDeleted() const { + bool IsDeleted() const { return sindex_ == std::numeric_limits::max(); } /*! \brief whether current node is root */ - inline bool IsRoot() const { + bool IsRoot() const { return parent_ == -1; } + /*! 
+ * \brief set the left child + * \param nid node id to left child + */ + void SetLeftChild(int nid) { + this->cleft_ = nid; + } /*! * \brief set the right child * \param nid node id to right child */ - inline void SetRightChild(int nid) { + void SetRightChild(int nid) { this->cright_ = nid; } /*! @@ -143,7 +160,7 @@ class TreeModel { * \param split_cond split condition * \param default_left the default direction when feature is unknown */ - inline void SetSplit(unsigned split_index, TSplitCond split_cond, + void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left = false) { if (default_left) split_index |= (1U << 31); this->sindex_ = split_index; @@ -155,25 +172,29 @@ class TreeModel { * \param right right index, could be used to store * additional information */ - inline void SetLeaf(bst_float value, int right = -1) { + void SetLeaf(bst_float value, int right = -1) { (this->info_).leaf_value = value; this->cleft_ = -1; this->cright_ = right; } /*! \brief mark that this node is deleted */ - inline void MarkDelete() { + void MarkDelete() { this->sindex_ = std::numeric_limits<unsigned>::max(); } + // set parent + void SetParent(int pidx, bool is_left_child = true) { + if (is_left_child) pidx |= (1U << 31); + this->parent_ = pidx; + } private: - friend class TreeModel; /*!
* \brief in leaf node, we have weights, in non-leaf nodes, * we have split condition */ union Info{ bst_float leaf_value; - TSplitCond split_cond; + SplitCondT split_cond; }; // pointer to parent, highest bit is used to // indicate whether it's a left child or not @@ -184,51 +205,14 @@ class TreeModel { unsigned sindex_{0}; // extra info Info info_; - // set parent - inline void SetParent(int pidx, bool is_left_child = true) { - if (is_left_child) pidx |= (1U << 31); - this->parent_ = pidx; - } }; - protected: - // vector of nodes - std::vector nodes_; - // free node space, used during training process - std::vector deleted_nodes_; - // stats of nodes - std::vector stats_; - // allocate a new node, - // !!!!!! NOTE: may cause BUG here, nodes.resize - inline int AllocNode() { - if (param.num_deleted != 0) { - int nd = deleted_nodes_.back(); - deleted_nodes_.pop_back(); - --param.num_deleted; - return nd; - } - int nd = param.num_nodes++; - CHECK_LT(param.num_nodes, std::numeric_limits::max()) - << "number of nodes in the tree exceed 2^31"; - nodes_.resize(param.num_nodes); - stats_.resize(param.num_nodes); - return nd; - } - // delete a tree node, keep the parent field to allow trace back - inline void DeleteNode(int nid) { - CHECK_GE(nid, param.num_roots); - deleted_nodes_.push_back(nid); - nodes_[nid].MarkDelete(); - ++param.num_deleted; - } - - public: /*! 
* \brief change a non leaf node to a leaf node, delete its children * \param rid node id of the node * \param value new leaf value */ - inline void ChangeToLeaf(int rid, bst_float value) { + void ChangeToLeaf(int rid, bst_float value) { CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf()); CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf()); this->DeleteNode(nodes_[rid].LeftChild()); @@ -240,7 +224,7 @@ class TreeModel { * \param rid node id of the node * \param value new leaf value */ - inline void CollapseToLeaf(int rid, bst_float value) { + void CollapseToLeaf(int rid, bst_float value) { if (nodes_[rid].IsLeaf()) return; if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) { CollapseToLeaf(nodes_[rid].LeftChild(), 0.0f); @@ -251,59 +235,53 @@ class TreeModel { this->ChangeToLeaf(rid, value); } - public: /*! \brief model parameter */ TreeParam param; /*! \brief constructor */ - TreeModel() { + RegTree() { param.num_nodes = 1; param.num_roots = 1; param.num_deleted = 0; - nodes_.resize(1); + nodes_.resize(param.num_nodes); + stats_.resize(param.num_nodes); + for (int i = 0; i < param.num_nodes; i ++) { + nodes_[i].SetLeaf(0.0f); + nodes_[i].SetParent(-1); + } } /*! \brief get node given nid */ - inline Node& operator[](int nid) { + Node& operator[](int nid) { return nodes_[nid]; } /*! \brief get node given nid */ - inline const Node& operator[](int nid) const { + const Node& operator[](int nid) const { return nodes_[nid]; } /*! \brief get const reference to nodes */ - inline const std::vector& GetNodes() const { return nodes_; } + const std::vector& GetNodes() const { return nodes_; } /*! \brief get node statistics given nid */ - inline NodeStat& Stat(int nid) { + RTreeNodeStat& Stat(int nid) { return stats_[nid]; } /*! \brief get node statistics given nid */ - inline const NodeStat& Stat(int nid) const { + const RTreeNodeStat& Stat(int nid) const { return stats_[nid]; } - /*! 
\brief initialize the model */ - inline void InitModel() { - param.num_nodes = param.num_roots; - nodes_.resize(param.num_nodes); - stats_.resize(param.num_nodes); - for (int i = 0; i < param.num_nodes; i ++) { - nodes_[i].SetLeaf(0.0f); - nodes_[i].SetParent(-1); - } - } /*! * \brief load model from stream * \param fi input stream */ - inline void Load(dmlc::Stream* fi) { + void Load(dmlc::Stream* fi) { CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam)); nodes_.resize(param.num_nodes); stats_.resize(param.num_nodes); CHECK_NE(param.num_nodes, 0); CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size()), sizeof(Node) * nodes_.size()); - CHECK_EQ(fi->Read(dmlc::BeginPtr(stats_), sizeof(NodeStat) * stats_.size()), - sizeof(NodeStat) * stats_.size()); + CHECK_EQ(fi->Read(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * stats_.size()), + sizeof(RTreeNodeStat) * stats_.size()); // chg deleted nodes deleted_nodes_.resize(0); for (int i = param.num_roots; i < param.num_nodes; ++i) { @@ -315,44 +293,34 @@ class TreeModel { * \brief save model to stream * \param fo output stream */ - inline void Save(dmlc::Stream* fo) const { + void Save(dmlc::Stream* fo) const { CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size())); CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size())); fo->Write(&param, sizeof(TreeParam)); CHECK_NE(param.num_nodes, 0); fo->Write(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size()); - fo->Write(dmlc::BeginPtr(stats_), sizeof(NodeStat) * nodes_.size()); + fo->Write(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * nodes_.size()); } /*!
* \brief add child nodes to node * \param nid node id to add children to */ - inline void AddChilds(int nid) { - int pleft = this->AllocNode(); + void AddChilds(int nid) { + int pleft = this->AllocNode(); int pright = this->AllocNode(); - nodes_[nid].cleft_ = pleft; - nodes_[nid].cright_ = pright; + nodes_[nid].SetLeftChild(pleft); + nodes_[nid].SetRightChild(pright); nodes_[nodes_[nid].LeftChild() ].SetParent(nid, true); nodes_[nodes_[nid].RightChild()].SetParent(nid, false); } - /*! - * \brief only add a right child to a leaf node - * \param nid node id to add right child - */ - inline void AddRightChild(int nid) { - int pright = this->AllocNode(); - nodes_[nid].right = pright; - nodes_[nodes_[nid].right].SetParent(nid, false); - } /*! * \brief get current depth * \param nid node id - * \param pass_rchild whether right child is not counted in depth */ - inline int GetDepth(int nid, bool pass_rchild = false) const { + int GetDepth(int nid) const { int depth = 0; while (!nodes_[nid].IsRoot()) { - if (!pass_rchild || nodes_[nid].IsLeftChild()) ++depth; + ++depth; nid = nodes_[nid].Parent(); } return depth; @@ -361,97 +329,65 @@ class TreeModel { * \brief get maximum depth * \param nid node id */ - inline int MaxDepth(int nid) const { + int MaxDepth(int nid) const { if (nodes_[nid].IsLeaf()) return 0; return std::max(MaxDepth(nodes_[nid].LeftChild())+1, MaxDepth(nodes_[nid].RightChild())+1); } + /*! * \brief get maximum depth */ - inline int MaxDepth() { + int MaxDepth() { int maxd = 0; for (int i = 0; i < param.num_roots; ++i) { maxd = std::max(maxd, MaxDepth(i)); } return maxd; } + /*! \brief number of extra nodes besides the root */ - inline int NumExtraNodes() const { + int NumExtraNodes() const { return param.num_nodes - param.num_roots - param.num_deleted; } -}; -/*! \brief node statistics used in regression tree */ -struct RTreeNodeStat { - /*! \brief loss change caused by current split */ - bst_float loss_chg; - /*! 
\brief sum of hessian values, used to measure coverage of data */ - bst_float sum_hess; - /*! \brief weight of current node */ - bst_float base_weight; - /*! \brief number of child that is leaf node known up to now */ - int leaf_child_cnt; -}; - -// Used by TreeShap -// data we keep about our decision path -// note that pweight is included for convenience and is not tied with the other attributes -// the pweight of the i'th path element is the permuation weight of paths with i-1 ones in them -struct PathElement { - int feature_index; - bst_float zero_fraction; - bst_float one_fraction; - bst_float pweight; - PathElement() = default; - PathElement(int i, bst_float z, bst_float o, bst_float w) : - feature_index(i), zero_fraction(z), one_fraction(o), pweight(w) {} -}; - -/*! - * \brief define regression tree to be the most common tree model. - * This is the data structure used in xgboost's major tree models. - */ -class RegTree: public TreeModel { - public: /*! * \brief dense feature vector that can be taken by RegTree * and can be construct from sparse feature vector. */ struct FVec { - public: /*! * \brief initialize the vector with size vector * \param size The size of the feature vector. */ - inline void Init(size_t size); + void Init(size_t size); /*! * \brief fill the vector with sparse vector * \param inst The sparse instance to fill. */ - inline void Fill(const SparsePage::Inst& inst); + void Fill(const SparsePage::Inst& inst); /*! * \brief drop the trace after fill, must be called after fill. * \param inst The sparse instance to drop. */ - inline void Drop(const SparsePage::Inst& inst); + void Drop(const SparsePage::Inst& inst); /*! * \brief returns the size of the feature vector * \return the size of the feature vector */ - inline size_t Size() const; + size_t Size() const; /*! * \brief get ith value * \param i feature index. * \return the i-th feature value */ - inline bst_float Fvalue(size_t i) const; + bst_float Fvalue(size_t i) const; /*! 
* \brief check whether i-th entry is missing * \param i feature index. * \return whether i-th value is missing. */ - inline bool IsMissing(size_t i) const; + bool IsMissing(size_t i) const; private: /*! @@ -470,14 +406,7 @@ class RegTree: public TreeModel { * \param root_id starting root index of the instance * \return the leaf index of the given feature */ - inline int GetLeafIndex(const FVec& feat, unsigned root_id = 0) const; - /*! - * \brief get the prediction of regression tree, only accepts dense feature vector - * \param feat dense feature vector, if the feature is missing the field is set to NaN - * \param root_id starting root index of the instance - * \return the leaf index of the given feature - */ - inline bst_float Predict(const FVec& feat, unsigned root_id = 0) const; + int GetLeafIndex(const FVec& feat, unsigned root_id = 0) const; /*! * \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree * \param feat dense feature vector, if the feature is missing the field is set to NaN @@ -486,10 +415,9 @@ class RegTree: public TreeModel { * \param condition fix one feature to either off (-1) on (1) or not fixed (0 default) * \param condition_feature the index of the feature to fix */ - inline void CalculateContributions(const RegTree::FVec& feat, unsigned root_id, - bst_float *out_contribs, - int condition = 0, - unsigned condition_feature = 0) const; + void CalculateContributions(const RegTree::FVec& feat, unsigned root_id, + bst_float* out_contribs, int condition = 0, + unsigned condition_feature = 0) const; /*! * \brief Recursive function that computes the feature attributions for a single tree. 
* \param feat dense feature vector, if the feature is missing the field is set to NaN @@ -504,12 +432,11 @@ class RegTree: public TreeModel { * \param condition_feature the index of the feature to fix * \param condition_fraction what fraction of the current weight matches our conditioning feature */ - inline void TreeShap(const RegTree::FVec& feat, bst_float *phi, - unsigned node_index, unsigned unique_depth, - PathElement *parent_unique_path, bst_float parent_zero_fraction, - bst_float parent_one_fraction, int parent_feature_index, - int condition, unsigned condition_feature, - bst_float condition_fraction) const; + void TreeShap(const RegTree::FVec& feat, bst_float* phi, unsigned node_index, + unsigned unique_depth, PathElement* parent_unique_path, + bst_float parent_zero_fraction, bst_float parent_one_fraction, + int parent_feature_index, int condition, + unsigned condition_feature, bst_float condition_fraction) const; /*! * \brief calculate the approximate feature contributions for the given root @@ -517,8 +444,8 @@ class RegTree: public TreeModel { * \param root_id starting root index of the instance * \param out_contribs output vector to hold the contributions */ - inline void CalculateContributionsApprox(const RegTree::FVec& feat, unsigned root_id, - bst_float *out_contribs) const; + void CalculateContributionsApprox(const RegTree::FVec& feat, unsigned root_id, + bst_float* out_contribs) const; /*! * \brief get next position of the tree given current pid * \param pid Current node id. @@ -539,16 +466,42 @@ class RegTree: public TreeModel { /*! 
* \brief calculate the mean value for each node, required for feature contributions */ - inline void FillNodeMeanValues(); + void FillNodeMeanValues(); private: - inline bst_float FillNodeMeanValue(int nid); - + // vector of nodes + std::vector nodes_; + // free node space, used during training process + std::vector deleted_nodes_; + // stats of nodes + std::vector stats_; std::vector node_mean_values_; + // allocate a new node, + // !!!!!! NOTE: may cause BUG here, nodes.resize + int AllocNode() { + if (param.num_deleted != 0) { + int nd = deleted_nodes_.back(); + deleted_nodes_.pop_back(); + --param.num_deleted; + return nd; + } + int nd = param.num_nodes++; + CHECK_LT(param.num_nodes, std::numeric_limits::max()) + << "number of nodes in the tree exceed 2^31"; + nodes_.resize(param.num_nodes); + stats_.resize(param.num_nodes); + return nd; + } + // delete a tree node, keep the parent field to allow trace back + void DeleteNode(int nid) { + CHECK_GE(nid, param.num_roots); + deleted_nodes_.push_back(nid); + nodes_[nid].MarkDelete(); + ++param.num_deleted; + } + bst_float FillNodeMeanValue(int nid); }; -// implementations of inline functions -// do not need to read if only use the model inline void RegTree::FVec::Init(size_t size) { Entry e; e.flag = -1; data_.resize(size); @@ -581,7 +534,8 @@ inline bool RegTree::FVec::IsMissing(size_t i) const { return data_[i].flag == -1; } -inline int RegTree::GetLeafIndex(const RegTree::FVec& feat, unsigned root_id) const { +inline int RegTree::GetLeafIndex(const RegTree::FVec& feat, + unsigned root_id) const { auto pid = static_cast(root_id); while (!(*this)[pid].IsLeaf()) { unsigned split_index = (*this)[pid].SplitIndex(); @@ -590,230 +544,6 @@ inline int RegTree::GetLeafIndex(const RegTree::FVec& feat, unsigned root_id) co return pid; } -inline bst_float RegTree::Predict(const RegTree::FVec& feat, unsigned root_id) const { - int pid = this->GetLeafIndex(feat, root_id); - return (*this)[pid].LeafValue(); -} - -inline void 
RegTree::FillNodeMeanValues() { - size_t num_nodes = this->param.num_nodes; - if (this->node_mean_values_.size() == num_nodes) { - return; - } - this->node_mean_values_.resize(num_nodes); - for (int root_id = 0; root_id < param.num_roots; ++root_id) { - this->FillNodeMeanValue(root_id); - } -} - -inline bst_float RegTree::FillNodeMeanValue(int nid) { - bst_float result; - auto& node = (*this)[nid]; - if (node.IsLeaf()) { - result = node.LeafValue(); - } else { - result = this->FillNodeMeanValue(node.LeftChild()) * this->Stat(node.LeftChild()).sum_hess; - result += this->FillNodeMeanValue(node.RightChild()) * this->Stat(node.RightChild()).sum_hess; - result /= this->Stat(nid).sum_hess; - } - this->node_mean_values_[nid] = result; - return result; -} - -inline void RegTree::CalculateContributionsApprox(const RegTree::FVec& feat, unsigned root_id, - bst_float *out_contribs) const { - CHECK_GT(this->node_mean_values_.size(), 0U); - // this follows the idea of http://blog.datadive.net/interpreting-random-forests/ - unsigned split_index = 0; - auto pid = static_cast(root_id); - // update bias value - bst_float node_value = this->node_mean_values_[pid]; - out_contribs[feat.Size()] += node_value; - if ((*this)[pid].IsLeaf()) { - // nothing to do anymore - return; - } - while (!(*this)[pid].IsLeaf()) { - split_index = (*this)[pid].SplitIndex(); - pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index)); - bst_float new_value = this->node_mean_values_[pid]; - // update feature weight - out_contribs[split_index] += new_value - node_value; - node_value = new_value; - } - bst_float leaf_value = (*this)[pid].LeafValue(); - // update leaf feature weight - out_contribs[split_index] += leaf_value - node_value; -} - -// extend our decision path with a fraction of one and zero extensions -inline void ExtendPath(PathElement *unique_path, unsigned unique_depth, - bst_float zero_fraction, bst_float one_fraction, int feature_index) { - 
unique_path[unique_depth].feature_index = feature_index; - unique_path[unique_depth].zero_fraction = zero_fraction; - unique_path[unique_depth].one_fraction = one_fraction; - unique_path[unique_depth].pweight = (unique_depth == 0 ? 1.0f : 0.0f); - for (int i = unique_depth - 1; i >= 0; i--) { - unique_path[i+1].pweight += one_fraction * unique_path[i].pweight * (i + 1) - / static_cast(unique_depth + 1); - unique_path[i].pweight = zero_fraction * unique_path[i].pweight * (unique_depth - i) - / static_cast(unique_depth + 1); - } -} - -// undo a previous extension of the decision path -inline void UnwindPath(PathElement *unique_path, unsigned unique_depth, unsigned path_index) { - const bst_float one_fraction = unique_path[path_index].one_fraction; - const bst_float zero_fraction = unique_path[path_index].zero_fraction; - bst_float next_one_portion = unique_path[unique_depth].pweight; - - for (int i = unique_depth - 1; i >= 0; --i) { - if (one_fraction != 0) { - const bst_float tmp = unique_path[i].pweight; - unique_path[i].pweight = next_one_portion * (unique_depth + 1) - / static_cast((i + 1) * one_fraction); - next_one_portion = tmp - unique_path[i].pweight * zero_fraction * (unique_depth - i) - / static_cast(unique_depth + 1); - } else { - unique_path[i].pweight = (unique_path[i].pweight * (unique_depth + 1)) - / static_cast(zero_fraction * (unique_depth - i)); - } - } - - for (auto i = path_index; i < unique_depth; ++i) { - unique_path[i].feature_index = unique_path[i+1].feature_index; - unique_path[i].zero_fraction = unique_path[i+1].zero_fraction; - unique_path[i].one_fraction = unique_path[i+1].one_fraction; - } -} - -// determine what the total permuation weight would be if -// we unwound a previous extension in the decision path -inline bst_float UnwoundPathSum(const PathElement *unique_path, unsigned unique_depth, - unsigned path_index) { - const bst_float one_fraction = unique_path[path_index].one_fraction; - const bst_float zero_fraction = 
unique_path[path_index].zero_fraction; - bst_float next_one_portion = unique_path[unique_depth].pweight; - bst_float total = 0; - for (int i = unique_depth - 1; i >= 0; --i) { - if (one_fraction != 0) { - const bst_float tmp = next_one_portion * (unique_depth + 1) - / static_cast((i + 1) * one_fraction); - total += tmp; - next_one_portion = unique_path[i].pweight - tmp * zero_fraction * ((unique_depth - i) - / static_cast(unique_depth + 1)); - } else { - total += (unique_path[i].pweight / zero_fraction) / ((unique_depth - i) - / static_cast(unique_depth + 1)); - } - } - return total; -} - -// recursive computation of SHAP values for a decision tree -inline void RegTree::TreeShap(const RegTree::FVec& feat, bst_float *phi, - unsigned node_index, unsigned unique_depth, - PathElement *parent_unique_path, bst_float parent_zero_fraction, - bst_float parent_one_fraction, int parent_feature_index, - int condition, unsigned condition_feature, - bst_float condition_fraction) const { - const auto node = (*this)[node_index]; - - // stop if we have no weight coming down to us - if (condition_fraction == 0) return; - - // extend the unique path - PathElement *unique_path = parent_unique_path + unique_depth + 1; - std::copy(parent_unique_path, parent_unique_path + unique_depth + 1, unique_path); - - if (condition == 0 || condition_feature != static_cast(parent_feature_index)) { - ExtendPath(unique_path, unique_depth, parent_zero_fraction, - parent_one_fraction, parent_feature_index); - } - const unsigned split_index = node.SplitIndex(); - - // leaf node - if (node.IsLeaf()) { - for (unsigned i = 1; i <= unique_depth; ++i) { - const bst_float w = UnwoundPathSum(unique_path, unique_depth, i); - const PathElement &el = unique_path[i]; - phi[el.feature_index] += w * (el.one_fraction - el.zero_fraction) - * node.LeafValue() * condition_fraction; - } - - // internal node - } else { - // find which branch is "hot" (meaning x would follow it) - unsigned hot_index = 0; - if 
(feat.IsMissing(split_index)) { - hot_index = node.DefaultChild(); - } else if (feat.Fvalue(split_index) < node.SplitCond()) { - hot_index = node.LeftChild(); - } else { - hot_index = node.RightChild(); - } - const unsigned cold_index = (static_cast(hot_index) == node.LeftChild() ? - node.RightChild() : node.LeftChild()); - const bst_float w = this->Stat(node_index).sum_hess; - const bst_float hot_zero_fraction = this->Stat(hot_index).sum_hess / w; - const bst_float cold_zero_fraction = this->Stat(cold_index).sum_hess / w; - bst_float incoming_zero_fraction = 1; - bst_float incoming_one_fraction = 1; - - // see if we have already split on this feature, - // if so we undo that split so we can redo it for this node - unsigned path_index = 0; - for (; path_index <= unique_depth; ++path_index) { - if (static_cast(unique_path[path_index].feature_index) == split_index) break; - } - if (path_index != unique_depth + 1) { - incoming_zero_fraction = unique_path[path_index].zero_fraction; - incoming_one_fraction = unique_path[path_index].one_fraction; - UnwindPath(unique_path, unique_depth, path_index); - unique_depth -= 1; - } - - // divide up the condition_fraction among the recursive calls - bst_float hot_condition_fraction = condition_fraction; - bst_float cold_condition_fraction = condition_fraction; - if (condition > 0 && split_index == condition_feature) { - cold_condition_fraction = 0; - unique_depth -= 1; - } else if (condition < 0 && split_index == condition_feature) { - hot_condition_fraction *= hot_zero_fraction; - cold_condition_fraction *= cold_zero_fraction; - unique_depth -= 1; - } - - TreeShap(feat, phi, hot_index, unique_depth + 1, unique_path, - hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction, - split_index, condition, condition_feature, hot_condition_fraction); - - TreeShap(feat, phi, cold_index, unique_depth + 1, unique_path, - cold_zero_fraction * incoming_zero_fraction, 0, - split_index, condition, condition_feature, 
cold_condition_fraction); - } -} - -inline void RegTree::CalculateContributions(const RegTree::FVec& feat, unsigned root_id, - bst_float *out_contribs, - int condition, - unsigned condition_feature) const { - // find the expected value of the tree's predictions - if (condition == 0) { - bst_float node_value = this->node_mean_values_[static_cast(root_id)]; - out_contribs[feat.Size()] += node_value; - } - - // Preallocate space for the unique path data - const int maxd = this->MaxDepth(root_id) + 2; - auto *unique_path_data = new PathElement[(maxd * (maxd + 1)) / 2]; - - TreeShap(feat, out_contribs, root_id, 0, unique_path_data, - 1, 1, -1, condition, condition_feature, 1); - delete[] unique_path_data; -} - /*! \brief get next position of the tree given current pid */ inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const { bst_float split_value = (*this)[pid].SplitCond(); diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 99b641131912..d2def6663693 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -272,7 +272,6 @@ class GBTree : public GradientBooster { // create new tree std::unique_ptr ptr(new RegTree()); ptr->param.InitAllowUnknown(this->cfg_); - ptr->InitModel(); new_trees.push_back(ptr.get()); ret->push_back(std::move(ptr)); } else if (tparam_.process_type == kUpdate) { diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index c03587c148b4..5052f1383bf0 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -169,4 +169,240 @@ std::string RegTree::DumpModel(const FeatureMap& fmap, } return fo.str(); } +void RegTree::FillNodeMeanValues() { + size_t num_nodes = this->param.num_nodes; + if (this->node_mean_values_.size() == num_nodes) { + return; + } + this->node_mean_values_.resize(num_nodes); + for (int root_id = 0; root_id < param.num_roots; ++root_id) { + this->FillNodeMeanValue(root_id); + } +} + +bst_float RegTree::FillNodeMeanValue(int nid) { + bst_float result; + auto& node = (*this)[nid]; + if 
(node.IsLeaf()) { + result = node.LeafValue(); + } else { + result = this->FillNodeMeanValue(node.LeftChild()) * this->Stat(node.LeftChild()).sum_hess; + result += this->FillNodeMeanValue(node.RightChild()) * this->Stat(node.RightChild()).sum_hess; + result /= this->Stat(nid).sum_hess; + } + this->node_mean_values_[nid] = result; + return result; +} + +void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat, + unsigned root_id, + bst_float *out_contribs) const { + CHECK_GT(this->node_mean_values_.size(), 0U); + // this follows the idea of http://blog.datadive.net/interpreting-random-forests/ + unsigned split_index = 0; + auto pid = static_cast<int>(root_id); + // update bias value + bst_float node_value = this->node_mean_values_[pid]; + out_contribs[feat.Size()] += node_value; + if ((*this)[pid].IsLeaf()) { + // nothing to do anymore + return; + } + while (!(*this)[pid].IsLeaf()) { + split_index = (*this)[pid].SplitIndex(); + pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index)); + bst_float new_value = this->node_mean_values_[pid]; + // update feature weight + out_contribs[split_index] += new_value - node_value; + node_value = new_value; + } + bst_float leaf_value = (*this)[pid].LeafValue(); + // update leaf feature weight + out_contribs[split_index] += leaf_value - node_value; +} + +// Used by TreeShap +// data we keep about our decision path +// note that pweight is included for convenience and is not tied with the other attributes +// the pweight of the i'th path element is the permutation weight of paths with i-1 ones in them +struct PathElement { + int feature_index; + bst_float zero_fraction; + bst_float one_fraction; + bst_float pweight; + PathElement() = default; + PathElement(int i, bst_float z, bst_float o, bst_float w) : + feature_index(i), zero_fraction(z), one_fraction(o), pweight(w) {} +}; + +// extend our decision path with a fraction of one and zero extensions +void ExtendPath(PathElement *unique_path, unsigned
unique_depth, + bst_float zero_fraction, bst_float one_fraction, + int feature_index) { + unique_path[unique_depth].feature_index = feature_index; + unique_path[unique_depth].zero_fraction = zero_fraction; + unique_path[unique_depth].one_fraction = one_fraction; + unique_path[unique_depth].pweight = (unique_depth == 0 ? 1.0f : 0.0f); + for (int i = unique_depth - 1; i >= 0; i--) { + unique_path[i+1].pweight += one_fraction * unique_path[i].pweight * (i + 1) + / static_cast(unique_depth + 1); + unique_path[i].pweight = zero_fraction * unique_path[i].pweight * (unique_depth - i) + / static_cast(unique_depth + 1); + } +} + +// undo a previous extension of the decision path +void UnwindPath(PathElement *unique_path, unsigned unique_depth, + unsigned path_index) { + const bst_float one_fraction = unique_path[path_index].one_fraction; + const bst_float zero_fraction = unique_path[path_index].zero_fraction; + bst_float next_one_portion = unique_path[unique_depth].pweight; + + for (int i = unique_depth - 1; i >= 0; --i) { + if (one_fraction != 0) { + const bst_float tmp = unique_path[i].pweight; + unique_path[i].pweight = next_one_portion * (unique_depth + 1) + / static_cast((i + 1) * one_fraction); + next_one_portion = tmp - unique_path[i].pweight * zero_fraction * (unique_depth - i) + / static_cast(unique_depth + 1); + } else { + unique_path[i].pweight = (unique_path[i].pweight * (unique_depth + 1)) + / static_cast(zero_fraction * (unique_depth - i)); + } + } + + for (auto i = path_index; i < unique_depth; ++i) { + unique_path[i].feature_index = unique_path[i+1].feature_index; + unique_path[i].zero_fraction = unique_path[i+1].zero_fraction; + unique_path[i].one_fraction = unique_path[i+1].one_fraction; + } +} + +// determine what the total permuation weight would be if +// we unwound a previous extension in the decision path +bst_float UnwoundPathSum(const PathElement *unique_path, unsigned unique_depth, + unsigned path_index) { + const bst_float one_fraction = 
unique_path[path_index].one_fraction; + const bst_float zero_fraction = unique_path[path_index].zero_fraction; + bst_float next_one_portion = unique_path[unique_depth].pweight; + bst_float total = 0; + for (int i = unique_depth - 1; i >= 0; --i) { + if (one_fraction != 0) { + const bst_float tmp = next_one_portion * (unique_depth + 1) + / static_cast((i + 1) * one_fraction); + total += tmp; + next_one_portion = unique_path[i].pweight - tmp * zero_fraction * ((unique_depth - i) + / static_cast(unique_depth + 1)); + } else { + total += (unique_path[i].pweight / zero_fraction) / ((unique_depth - i) + / static_cast(unique_depth + 1)); + } + } + return total; +} + +// recursive computation of SHAP values for a decision tree +void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi, + unsigned node_index, unsigned unique_depth, + PathElement *parent_unique_path, + bst_float parent_zero_fraction, + bst_float parent_one_fraction, int parent_feature_index, + int condition, unsigned condition_feature, + bst_float condition_fraction) const { + const auto node = (*this)[node_index]; + + // stop if we have no weight coming down to us + if (condition_fraction == 0) return; + + // extend the unique path + PathElement *unique_path = parent_unique_path + unique_depth + 1; + std::copy(parent_unique_path, parent_unique_path + unique_depth + 1, unique_path); + + if (condition == 0 || condition_feature != static_cast(parent_feature_index)) { + ExtendPath(unique_path, unique_depth, parent_zero_fraction, + parent_one_fraction, parent_feature_index); + } + const unsigned split_index = node.SplitIndex(); + + // leaf node + if (node.IsLeaf()) { + for (unsigned i = 1; i <= unique_depth; ++i) { + const bst_float w = UnwoundPathSum(unique_path, unique_depth, i); + const PathElement &el = unique_path[i]; + phi[el.feature_index] += w * (el.one_fraction - el.zero_fraction) + * node.LeafValue() * condition_fraction; + } + + // internal node + } else { + // find which branch is "hot" 
(meaning x would follow it) + unsigned hot_index = 0; + if (feat.IsMissing(split_index)) { + hot_index = node.DefaultChild(); + } else if (feat.Fvalue(split_index) < node.SplitCond()) { + hot_index = node.LeftChild(); + } else { + hot_index = node.RightChild(); + } + const unsigned cold_index = (static_cast(hot_index) == node.LeftChild() ? + node.RightChild() : node.LeftChild()); + const bst_float w = this->Stat(node_index).sum_hess; + const bst_float hot_zero_fraction = this->Stat(hot_index).sum_hess / w; + const bst_float cold_zero_fraction = this->Stat(cold_index).sum_hess / w; + bst_float incoming_zero_fraction = 1; + bst_float incoming_one_fraction = 1; + + // see if we have already split on this feature, + // if so we undo that split so we can redo it for this node + unsigned path_index = 0; + for (; path_index <= unique_depth; ++path_index) { + if (static_cast(unique_path[path_index].feature_index) == split_index) break; + } + if (path_index != unique_depth + 1) { + incoming_zero_fraction = unique_path[path_index].zero_fraction; + incoming_one_fraction = unique_path[path_index].one_fraction; + UnwindPath(unique_path, unique_depth, path_index); + unique_depth -= 1; + } + + // divide up the condition_fraction among the recursive calls + bst_float hot_condition_fraction = condition_fraction; + bst_float cold_condition_fraction = condition_fraction; + if (condition > 0 && split_index == condition_feature) { + cold_condition_fraction = 0; + unique_depth -= 1; + } else if (condition < 0 && split_index == condition_feature) { + hot_condition_fraction *= hot_zero_fraction; + cold_condition_fraction *= cold_zero_fraction; + unique_depth -= 1; + } + + TreeShap(feat, phi, hot_index, unique_depth + 1, unique_path, + hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction, + split_index, condition, condition_feature, hot_condition_fraction); + + TreeShap(feat, phi, cold_index, unique_depth + 1, unique_path, + cold_zero_fraction * incoming_zero_fraction, 0, + 
split_index, condition, condition_feature, cold_condition_fraction); + } +} + +void RegTree::CalculateContributions(const RegTree::FVec &feat, + unsigned root_id, bst_float *out_contribs, + int condition, + unsigned condition_feature) const { + // find the expected value of the tree's predictions + if (condition == 0) { + bst_float node_value = this->node_mean_values_[static_cast(root_id)]; + out_contribs[feat.Size()] += node_value; + } + + // Preallocate space for the unique path data + const int maxd = this->MaxDepth(root_id) + 2; + auto *unique_path_data = new PathElement[(maxd * (maxd + 1)) / 2]; + + TreeShap(feat, out_contribs, root_id, 0, unique_path_data, + 1, 1, -1, condition, condition_feature, 1); + delete[] unique_path_data; +} } // namespace xgboost diff --git a/src/tree/updater_prune.cc b/src/tree/updater_prune.cc index 61411e40aede..8de43e7ad195 100644 --- a/src/tree/updater_prune.cc +++ b/src/tree/updater_prune.cc @@ -48,7 +48,7 @@ class TreePruner: public TreeUpdater { inline int TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) { // NOLINT(*) if (tree[nid].IsRoot()) return npruned; int pid = tree[nid].Parent(); - RegTree::NodeStat &s = tree.Stat(pid); + RTreeNodeStat &s = tree.Stat(pid); ++s.leaf_child_cnt; if (s.leaf_child_cnt >= 2 && param_.NeedPrune(s.loss_chg, depth - 1)) { // need to be pruned diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index f9e575c29b64..752fdee92e01 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -10,7 +10,6 @@ TEST(cpu_predictor, Test) { std::vector> trees; trees.push_back(std::unique_ptr(new RegTree)); - trees.back()->InitModel(); (*trees.back())[0].SetLeaf(1.5f); (*trees.back()).Stat(0).sum_hess = 1.0f; gbm::GBTreeModel model(0.5); diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index b345a6ef3c9a..6302e05babdb 100644 --- 
a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -35,7 +35,6 @@ TEST(gpu_predictor, Test) { std::vector> trees; trees.push_back(std::unique_ptr(new RegTree())); - trees.back()->InitModel(); (*trees.back())[0].SetLeaf(1.5f); (*trees.back()).Stat(0).sum_hess = 1.0f; gbm::GBTreeModel model(0.5); @@ -181,7 +180,6 @@ TEST(gpu_predictor, MGPU_Test) { std::vector> trees; trees.push_back(std::unique_ptr(new RegTree())); - trees.back()->InitModel(); (*trees.back())[0].SetLeaf(1.5f); (*trees.back()).Stat(0).sum_hess = 1.0f; gbm::GBTreeModel model(0.5); diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 6a36eecc2208..83896ebc8536 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -291,8 +291,6 @@ TEST(GpuHist, EvaluateSplits) { false); RegTree tree; - tree.InitModel(); - MetaInfo info; info.num_row_ = n_rows; info.num_col_ = n_cols; @@ -339,7 +337,6 @@ TEST(GpuHist, ApplySplit) { // Initialize GPUHistMaker hist_maker.param_ = param; RegTree tree; - tree.InitModel(); DeviceSplitCandidate candidate; candidate.Update(2, kLeftDir, diff --git a/tests/cpp/tree/test_prune.cc b/tests/cpp/tree/test_prune.cc index 0b1878e4164a..fbebf47b7415 100644 --- a/tests/cpp/tree/test_prune.cc +++ b/tests/cpp/tree/test_prune.cc @@ -31,7 +31,6 @@ TEST(Updater, Prune) { // prepare tree RegTree tree = RegTree(); - tree.InitModel(); tree.param.InitAllowUnknown(cfg); std::vector trees {&tree}; // prepare pruner diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index 5ac8575a136d..d91b69c48528 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -122,7 +122,6 @@ class QuantileHistMock : public QuantileHistMaker { gmat.Init((*dmat).get(), max_bins); RegTree tree = RegTree(); - tree.InitModel(); tree.param.InitAllowUnknown(cfg); std::vector gpair = @@ -134,7 +133,6 @@ class QuantileHistMock : public 
QuantileHistMaker { void TestBuildHist() { RegTree tree = RegTree(); - tree.InitModel(); tree.param.InitAllowUnknown(cfg); size_t constexpr max_bins = 4; @@ -146,7 +144,6 @@ class QuantileHistMock : public QuantileHistMaker { void TestEvaluateSplit() { RegTree tree = RegTree(); - tree.InitModel(); tree.param.InitAllowUnknown(cfg); builder_->TestEvaluateSplit(gmatb_, tree); diff --git a/tests/cpp/tree/test_refresh.cc b/tests/cpp/tree/test_refresh.cc index 78d2db2f3ee5..d1e66edb1757 100644 --- a/tests/cpp/tree/test_refresh.cc +++ b/tests/cpp/tree/test_refresh.cc @@ -25,7 +25,6 @@ TEST(Updater, Refresh) { {"reg_lambda", "1"}}; RegTree tree = RegTree(); - tree.InitModel(); tree.param.InitAllowUnknown(cfg); std::vector trees {&tree}; std::unique_ptr refresher(TreeUpdater::Create("refresh"));