From 53e6b57b4520ee3bb71551acd5bb54e12f865997 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 6 Mar 2020 16:47:45 +0800 Subject: [PATCH 1/5] [POC] Implement Exact tree method for multi-target. * Added a new exact tree method. * Specialize many utilities for it. --- amalgamation/xgboost-all0.cc | 1 + demo/multi-target/.gitignore | 1 + demo/multi-target/regression.py | 142 ++++ include/xgboost/base.h | 27 +- include/xgboost/data.h | 7 +- include/xgboost/learner.h | 22 - include/xgboost/model.h | 31 + include/xgboost/tree_model.h | 152 +++- include/xgboost/tree_updater.h | 12 +- python-package/xgboost/core.py | 33 +- python-package/xgboost/experimental.py | 21 + python-package/xgboost/sklearn.py | 10 +- src/c_api/c_api.cc | 7 +- src/common/observer.h | 4 +- src/common/quantile.h | 3 +- src/data/array_interface.h | 2 +- src/data/data.cc | 35 +- src/data/data.cu | 2 +- src/gbm/gblinear.cc | 2 +- src/gbm/gblinear_model.h | 1 - src/gbm/gbm.cc | 2 +- src/gbm/gbtree.cc | 27 +- src/gbm/gbtree.h | 9 +- src/gbm/gbtree_model.h | 12 +- src/learner.cc | 54 +- src/metric/multiclass_metric.cu | 93 ++- src/predictor/cpu_predictor.cc | 150 +++- src/predictor/gpu_predictor.cu | 2 +- src/tree/param.h | 7 +- src/tree/split_evaluator.cc | 1 + src/tree/tree_model.cc | 20 +- src/tree/tree_updater.cc | 9 +- src/tree/updater_colmaker.cc | 946 ++++++++++++------------- src/tree/updater_exact.cc | 448 ++++++++++++ src/tree/updater_exact.h | 516 ++++++++++++++ src/tree/updater_gpu_hist.cu | 4 +- src/tree/updater_histmaker.cc | 4 +- src/tree/updater_prune.cc | 11 +- src/tree/updater_quantile_hist.cc | 12 +- src/tree/updater_quantile_hist.h | 3 +- src/tree/updater_refresh.cc | 2 +- src/tree/updater_skmaker.cc | 2 +- src/tree/updater_sync.cc | 2 +- tests/cpp/common/test_hist_util.h | 2 +- tests/cpp/data/test_metainfo.cc | 20 + tests/cpp/data/test_metainfo.cu | 18 +- tests/cpp/gbm/test_gbtree.cc | 2 +- tests/cpp/helpers.cc | 24 +- tests/cpp/helpers.h | 12 +- tests/cpp/predictor/test_predictor.cc | 1 + tests/cpp/tree/test_exact.cc | 196 +++++ tests/cpp/tree/test_gpu_hist.cu | 4 +- tests/cpp/tree/test_histmaker.cc | 8 +- tests/cpp/tree/test_prune.cc | 3 +- tests/cpp/tree/test_quantile_hist.cc | 41 +- tests/cpp/tree/test_refresh.cc | 11 +- tests/cpp/tree/test_tree_model.cc | 2 +- tests/cpp/tree/test_tree_stat.cc | 3 +- 58 files changed, 2471 insertions(+), 727 deletions(-) create mode 100644 demo/multi-target/.gitignore create mode 100644 demo/multi-target/regression.py create mode 100644 python-package/xgboost/experimental.py create mode 100644 src/tree/updater_exact.cc create mode 100644 src/tree/updater_exact.h create mode 100644 tests/cpp/tree/test_exact.cc diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index 6e8c09b7d5c7..43a21e4e37b3 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -52,6 +52,7 @@ #include "../src/tree/tree_model.cc" #include "../src/tree/tree_updater.cc" #include "../src/tree/updater_colmaker.cc" +#include "../src/tree/updater_exact.cc" #include "../src/tree/updater_quantile_hist.cc" #include "../src/tree/updater_prune.cc" #include "../src/tree/updater_refresh.cc" diff --git a/demo/multi-target/.gitignore b/demo/multi-target/.gitignore new file mode 100644 index 000000000000..aab52d906fa2 --- /dev/null +++ b/demo/multi-target/.gitignore @@ -0,0 +1 @@ +*.png \ No newline at end of file diff --git a/demo/multi-target/regression.py b/demo/multi-target/regression.py new file mode 100644 index 000000000000..1d1994782792 --- /dev/null +++ 
b/demo/multi-target/regression.py
@@ -0,0 +1,142 @@
+'''The example is taken from:
+https://scikit-learn.org/stable/auto_examples/tree/plot_tree_regression_multioutput.html#sphx-glr-auto-examples-tree-plot-tree-regression-multioutput-py
+
+A multi-target tree may have lower accuracy due to its smaller model capacity,
+but it provides better computational performance for prediction.
+
+The current implementation supports only the exact tree method and is highly
+experimental. We do not recommend any real-world usage.
+
+There are 3 different ways to train a multi-target model:
+
+- Train 1 model for each target manually. See `train_stacked_native` below.
+- Train 1 stack of trees for each target by XGBoost. This is the default
+  implementation, with `output_type` set to `single`.
+- Train 1 stack of trees for all target variables, with the tree leaf being a
+  vector. This can be enabled by setting `output_type` to `multi`.
+
+'''
+
+import numpy as np
+from matplotlib import pyplot as plt
+import xgboost as xgb
+from xgboost.experimental import XGBMultiRegressor
+import argparse
+
+# Generate some random data with y being a circle.
+rng = np.random.RandomState(1994)
+X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
+y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T
+y[::5, :] += (0.5 - rng.rand(20, 2))
+
+boosted_rounds = 32
+
+y = y - y.min()
+y: np.ndarray = y / y.max()
+y = y.copy()
+
+
+def plot_predt(y, y_predt, name):
+    '''Plot the output prediction along with labels.
+
+    Parameters
+    ----------
+    y : np.ndarray
+        labels
+    y_predt : np.ndarray
+        prediction from XGBoost.
+    name : str
+        output file name for matplotlib.
+    '''
+    s = 25
+    plt.scatter(y[:, 0], y[:, 1], c="navy", s=s,
+                edgecolor="black", label="data")
+    plt.scatter(y_predt[:, 0], y_predt[:, 1], c="cornflowerblue", s=s,
+                edgecolor="black", label="max_depth=2")
+    plt.xlim([-1, 2])
+    plt.ylim([-1, 2])
+    plt.savefig(name + '.png')
+    plt.close()
+
+
+def train_multi_skl():
+    '''Train a multi-target regression model with XGBoost's scikit-learn
+    interface. This method demonstrates training multi-target trees with a
+    vector as the leaf value, and also training a model that uses
+    single-target trees with one stack of trees for each target variable.
+
+    '''
+    # Train with vector leaf trees.
+    reg = XGBMultiRegressor(output_type='multi',
+                            num_targets=y.shape[1],
+                            n_estimators=boosted_rounds)
+    reg.fit(X, y, eval_set=[(X, y)])
+    y_predt = reg.predict(X)
+    plot_predt(y, y_predt, 'skl-multi')
+
+    # Train 1 stack of trees for each target variable.
+    reg = XGBMultiRegressor(output_type='single',
+                            num_targets=y.shape[1],
+                            n_estimators=boosted_rounds)
+    reg.fit(X, y, eval_set=[(X, y)])
+    y_predt = reg.predict(X)
+    plot_predt(y, y_predt, 'skl-single')
+
+
+def train_multi_native():
+    '''Train a multi-target regression model with the native XGBoost
+    interface. This method demonstrates training multi-target trees with a
+    vector as the leaf value, and also training a model that uses
+    single-target trees with one stack of trees for each target variable.
+
+    '''
+    d = xgb.DMatrix(X, y)
+    # Train with vector leaf trees.
+    booster = xgb.train({'tree_method': 'exact',
+                         'nthread': 16,
+                         'output_type': 'multi',
+                         'num_targets': y.shape[1],
+                         'objective': 'reg:squarederror'
+                         }, d,
+                        num_boost_round=boosted_rounds,
+                        evals=[(d, 'Train')])
+    y_predt = booster.predict(d)
+    plot_predt(y, y_predt, 'native-multi')
+
+    # Train 1 stack of trees for each target variable.
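+    # The vector-leaf model above grows one tree per boosting round whose
+    # leaves each hold a weight for every target, so `y_predt` already has
+    # shape (n_samples, n_targets). The run below instead grows one stack of
+    # trees per target (much like multi-class training grows one stack per
+    # class) and yields the same 2-d prediction layout.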
+ booster = xgb.train({'tree_method': 'exact', + 'nthread': 16, + 'output_type': 'single', + 'num_targets': y.shape[1], + 'objective': 'reg:squarederror' + }, d, + num_boost_round=boosted_rounds, + evals=[(d, 'Train')]) + y_predt = booster.predict(d) + plot_predt(y, y_predt, 'native-single') + + +def train_stacked_native(): + '''Train 2 XGBoost models, each one targeting a single output variable.''' + # Extract the first target variable + d = xgb.DMatrix(X, y[:, 0].copy()) + params = {'tree_method': 'exact', + 'objective': 'reg:squarederror'} + booster = xgb.train( + params, d, num_boost_round=boosted_rounds, evals=[(d, 'Train')]) + y_predt_0 = booster.predict(d) + + # Extract the second target variable + d = xgb.DMatrix(X, y[:, 1].copy()) + booster = xgb.train(params, d, num_boost_round=boosted_rounds) + y_predt_1 = booster.predict(d) + y_predt = np.stack([y_predt_0, y_predt_1], axis=-1) + plot_predt(y, y_predt, 'stacked') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + args = parser.parse_args() + + train_multi_native() + train_multi_skl() + train_stacked_native() diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 4802426112e5..20009112a2f8 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -57,16 +57,23 @@ #if defined(__GNUC__) && ((__GNUC__ == 4 && __GNUC_MINOR__ >= 8) || __GNUC__ > 4) && \ !defined(__CUDACC__) #include +#include #define XGBOOST_PARALLEL_SORT(X, Y, Z) __gnu_parallel::sort((X), (Y), (Z)) #define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) \ __gnu_parallel::stable_sort((X), (Y), (Z)) +#define XGBOOST_PARALLEL_ACCUMULATE(__BEG, __END, __INIT, __OP) \ + __gnu_parallel::accumulate(__BEG, __END, __INIT, __OP) #elif defined(_MSC_VER) && (!__INTEL_COMPILER) #include #define XGBOOST_PARALLEL_SORT(X, Y, Z) concurrency::parallel_sort((X), (Y), (Z)) #define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) std::stable_sort((X), (Y), (Z)) +#define XGBOOST_PARALLEL_ACCUMULATE(__BEG, __END, __INIT, __OP) \ + std::accumulate(__BEG, __END, __INIT, __OP) #else #define XGBOOST_PARALLEL_SORT(X, Y, Z) std::sort((X), (Y), (Z)) #define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) std::stable_sort((X), (Y), (Z)) +#define XGBOOST_PARALLEL_ACCUMULATE(__BEG, __END, __INIT, __OP) \ + std::accumulate(__BEG, __END, __INIT, __OP) #endif // GLIBC VERSION #if defined(__GNUC__) @@ -135,8 +142,8 @@ class GradientPairInternal { /*! 
\brief second order gradient statistics */ T hess_; - XGBOOST_DEVICE void SetGrad(T g) { grad_ = g; } - XGBOOST_DEVICE void SetHess(T h) { hess_ = h; } + XGBOOST_DEVICE void SetGrad(T g) { grad_ = std::move(g); } + XGBOOST_DEVICE void SetHess(T h) { hess_ = std::move(h); } public: using ValueT = T; @@ -150,12 +157,9 @@ class GradientPairInternal { a += b; } - XGBOOST_DEVICE GradientPairInternal() : grad_(0), hess_(0) {} - - XGBOOST_DEVICE GradientPairInternal(T grad, T hess) { - SetGrad(grad); - SetHess(hess); - } + constexpr XGBOOST_DEVICE GradientPairInternal() : grad_(0), hess_(0) {} + constexpr XGBOOST_DEVICE GradientPairInternal(T grad, T hess) + : grad_{std::move(grad)}, hess_{std::move(hess)} {} // Copy constructor if of same value type, marked as default to be trivially_copyable GradientPairInternal(const GradientPairInternal &g) = default; @@ -168,8 +172,11 @@ class GradientPairInternal { SetHess(g.GetHess()); } - XGBOOST_DEVICE T GetGrad() const { return grad_; } - XGBOOST_DEVICE T GetHess() const { return hess_; } + XGBOOST_DEVICE T const& GetGrad() const { return grad_; } + XGBOOST_DEVICE T const& GetHess() const { return hess_; } + + XGBOOST_DEVICE T& GetGrad() { return grad_; } + XGBOOST_DEVICE T& GetHess() { return hess_; } XGBOOST_DEVICE GradientPairInternal &operator+=( const GradientPairInternal &rhs) { diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 57babfafe126..903af20e44bc 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -50,6 +50,8 @@ class MetaInfo { uint64_t num_nonzero_{0}; // NOLINT /*! \brief label of each instance */ HostDeviceVector labels_; // NOLINT + bst_row_t labels_rows; + bst_feature_t labels_cols { 1 }; /*! * \brief the index of begin and end of a group * needed when the learning task is ranking. @@ -156,7 +158,7 @@ class MetaInfo { * * Right now only 1 column is permitted. */ - void SetInfo(const char* key, std::string const& interface_str); + void SetInfo(const char* key, std::string const& interface_str, int32_t device); /* * \brief Extend with other MetaInfo. @@ -169,6 +171,7 @@ class MetaInfo { void Extend(MetaInfo const& that, bool accumulate_rows); private: + void SetInfoDevice(const char* key, std::string const& interface_str); /*! \brief argsort of labels */ mutable std::vector label_order_cache_; }; @@ -446,7 +449,7 @@ class DMatrix { this->Info().SetInfo(key, dptr, dtype, num); } virtual void SetInfo(const char* key, std::string const& interface_str) { - this->Info().SetInfo(key, interface_str); + this->Info().SetInfo(key, interface_str, 0); } /*! \brief meta information of the dataset */ virtual const MetaInfo& Info() const = 0; diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index a608bc1b8206..13f1b687ffa8 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -226,27 +226,5 @@ class Learner : public Model, public Configurable, public rabit::Serializable { /*! \brief Training parameter. */ GenericParameter generic_parameters_; }; - -struct LearnerModelParamLegacy; - -/* - * \brief Basic Model Parameters, used to describe the booster. - */ -struct LearnerModelParam { - /* \brief global bias */ - bst_float base_score { 0.5f }; - /* \brief number of features */ - uint32_t num_feature { 0 }; - /* \brief number of classes, if it is multi-class classification */ - uint32_t num_output_group { 0 }; - - LearnerModelParam() = default; - // As the old `LearnerModelParamLegacy` is still used by binary IO, we keep - // this one as an immutable copy. 
- LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin); - /* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */ - bool Initialized() const { return num_feature != 0; } -}; - } // namespace xgboost #endif // XGBOOST_LEARNER_H_ diff --git a/include/xgboost/model.h b/include/xgboost/model.h index 3b661ae814b8..bd24a7fea3dd 100644 --- a/include/xgboost/model.h +++ b/include/xgboost/model.h @@ -6,6 +6,8 @@ #ifndef XGBOOST_MODEL_H_ #define XGBOOST_MODEL_H_ +#include + namespace dmlc { class Stream; } // namespace dmlc @@ -41,6 +43,35 @@ struct Configurable { */ virtual void SaveConfig(Json* out) const = 0; }; + +struct LearnerModelParamLegacy; + +enum class OutputType : int32_t { + kSingle, + kMulti +}; + +/* + * \brief Basic Model Parameters, used to describe the booster. + */ +struct LearnerModelParam { + /* \brief global bias */ + float base_score { 0.5 }; + /* \brief number of features */ + uint32_t num_feature { 0 }; + /* \brief number of classes, if it is multi-class classification */ + uint32_t num_output_group { 0 }; + /* \brief number of target variables. */ + uint32_t num_targets { 1 }; + OutputType output_type { OutputType::kSingle }; + + LearnerModelParam() = default; + // As the old `LearnerModelParamLegacy` is still used by binary IO, we keep + // this one as an immutable copy. + LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin); + /* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */ + bool Initialized() const { return num_feature != 0; } +}; } // namespace xgboost #endif // XGBOOST_MODEL_H_ diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index e7f6dc8ec089..f95e854f8121 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -99,11 +99,53 @@ struct RTreeNodeStat { } }; +class MultiTargetTreeNodeStat { + std::vector loss_chg_; + std::vector sum_hess_; + std::vector base_weight_; + size_t targets_; + + public: + explicit MultiTargetTreeNodeStat(size_t targets) : targets_{targets} {} + void Set(bst_node_t nidx, float loss_chg, common::Span weight, + common::Span hess) { + if (loss_chg_.size() < static_cast(nidx + 1)) { + loss_chg_.resize(nidx + 1); + } + loss_chg_[nidx] = loss_chg; + size_t beg = nidx * targets_; + size_t end = beg + targets_; + if (sum_hess_.size() < end) { + sum_hess_.resize(end); + base_weight_.resize(end); + } + for (size_t i = beg; i < end; ++i) { + sum_hess_[i] = hess[i - beg]; + base_weight_[i] = weight[i - beg]; + } + } + void Prune(bst_node_t nidx) { + loss_chg_[nidx] = std::numeric_limits::quiet_NaN(); + } + bool IsDeleted(bst_node_t nidx) const { + return std::isnan(loss_chg_[nidx]); + } +}; + /*! * \brief define regression tree to be the most common tree model. * This is the data structure used in xgboost's major tree models. */ class RegTree : public Model { + public: + enum TreeKind : int { + kSingle, + kMulti + }; + + private: + TreeKind kind_ {kSingle}; + public: using SplitCondT = bst_float; static constexpr bst_node_t kInvalidNodeId {-1}; @@ -150,7 +192,7 @@ class RegTree : public Model { return cleft_ == kInvalidNodeId; } /*! \return get leaf value of leaf node */ - XGBOOST_DEVICE bst_float LeafValue() const { + XGBOOST_DEVICE bst_float SinlgeLeafValue() const { return (this->info_).leaf_value; } /*! 
\return get split condition of the node */ @@ -247,6 +289,30 @@ class RegTree : public Model { Info info_; }; + explicit RegTree(bst_feature_t leaf_size = 1, TreeKind kind = kSingle) : + kind_{kind}, leaf_size_{leaf_size}, multi_target_stats_{leaf_size} { + param.num_nodes = 1; + param.num_deleted = 0; + nodes_.resize(param.num_nodes); + stats_.resize(param.num_nodes); + for (int i = 0; i < param.num_nodes; i ++) { + nodes_[i].SetLeaf(0.0f); + nodes_[i].SetParent(kInvalidNodeId); + } + + if (leaf_size_ != 1) { + leaf_values_.resize(leaf_size_); + CHECK_EQ(kind_, kMulti); + } + } + /*! + * \brief Return tree kind, kSingle or kMulti. + */ + TreeKind Kind() const { return kind_; } + /*! + * \brief Return the size of leaf. + */ + bst_feature_t LeafSize() const { return leaf_size_; } /*! * \brief change a non leaf node to a leaf node, delete its children * \param rid node id of the node @@ -275,19 +341,6 @@ class RegTree : public Model { this->ChangeToLeaf(rid, value); } - /*! \brief model parameter */ - TreeParam param; - /*! \brief constructor */ - RegTree() { - param.num_nodes = 1; - param.num_deleted = 0; - nodes_.resize(param.num_nodes); - stats_.resize(param.num_nodes); - for (int i = 0; i < param.num_nodes; i ++) { - nodes_[i].SetLeaf(0.0f); - nodes_[i].SetParent(kInvalidNodeId); - } - } /*! \brief get node given nid */ Node& operator[](int nid) { return nodes_[nid]; @@ -402,6 +455,33 @@ class RegTree : public Model { this->Stat(pright) = {0.0f, right_sum, right_leaf_weight}; } + void ExpandNode(int nid, unsigned split_index, + bst_float split_value, + bool default_left, std::vector const& base_weight, + std::vector const& left_leaf_weight, + std::vector const& right_leaf_weight, + bst_float loss_change, + std::vector const& sum_hess, + std::vector const& left_sum, std::vector const& right_sum) { + int pleft = this->AllocNode(); + int pright = this->AllocNode(); + auto &node = nodes_[nid]; + CHECK(node.IsLeaf()); + node.SetLeftChild(pleft); + node.SetRightChild(pright); + nodes_[node.LeftChild()].SetParent(nid, true); + nodes_[node.RightChild()].SetParent(nid, false); + node.SetSplit(split_index, split_value, + default_left); + this->SetLeaf(left_leaf_weight, pleft); + this->SetLeaf(right_leaf_weight, pright); + this->multi_target_stats_.Set(nid, loss_change, + common::Span{base_weight}, + common::Span{sum_hess}); + this->multi_target_stats_.Set(pleft, 0, {left_leaf_weight}, {left_sum}); + this->multi_target_stats_.Set(pright, 0, {right_leaf_weight}, {right_sum}); + } + /*! 
* \brief get current depth * \param nid node id @@ -489,6 +569,42 @@ class RegTree : public Model { }; std::vector data_; }; + + common::Span VectorLeafValue(bst_node_t nidx) const { + CHECK_EQ(kind_, kMulti); + auto s = common::Span {leaf_values_}.subspan(nidx * leaf_size_, leaf_size_); + return s; + } + float LeafValue(bst_node_t nidx) const { + CHECK_EQ(kind_, kSingle); + return (*this)[nidx].SinlgeLeafValue(); + } + + void SetLeaf(std::vector const& leaf, bst_node_t nid, + std::vector const& sum_hess = {}) { + (*this)[nid].SetLeaf(0); + auto offset = LeafSize() * nid; + if (leaf_values_.size() < offset + LeafSize()) { + leaf_values_.resize(offset + LeafSize()); + } + CHECK_EQ(leaf.size(), LeafSize()); + for (size_t i = 0; i < LeafSize(); ++i) { + leaf_values_[i + offset] = leaf[i]; + } + (*this)[nid].SetLeftChild(kInvalidNodeId); + (*this)[nid].SetRightChild(kInvalidNodeId); + + if (sum_hess.size() != 0) { + this->multi_target_stats_.Set(nid, 0, leaf, sum_hess); + } + } + void SetLeaf(float const& leaf, bst_node_t nid, + double sum_hess = 0) { + (*this)[nid].SetLeaf(leaf); + this->Stat(nid).loss_chg = 0; + this->Stat(nid).base_weight = leaf; + this->Stat(nid).sum_hess = sum_hess; + } /*! * \brief get the leaf index * \param feat dense feature vector, if the feature is missing the field is set to NaN @@ -554,6 +670,9 @@ class RegTree : public Model { */ void FillNodeMeanValues(); + /*! \brief model parameter */ + TreeParam param; + private: // vector of nodes std::vector nodes_; @@ -562,6 +681,11 @@ class RegTree : public Model { // stats of nodes std::vector stats_; std::vector node_mean_values_; + + bst_feature_t leaf_size_ {0}; + std::vector leaf_values_; + MultiTargetTreeNodeStat multi_target_stats_; + // allocate a new node, // !!!!!! NOTE: may cause BUG here, nodes.resize int AllocNode() { diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h index a091c81b045b..09b444250d38 100644 --- a/include/xgboost/tree_updater.h +++ b/include/xgboost/tree_updater.h @@ -30,7 +30,7 @@ class Json; */ class TreeUpdater : public Configurable { protected: - GenericParameter const* tparam_; + GenericParameter const* tparam_ { nullptr }; public: /*! \brief virtual destructor */ @@ -82,16 +82,18 @@ class TreeUpdater : public Configurable { * \param name Name of the tree updater. * \param tparam A global runtime parameter */ - static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam); + static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam, + LearnerModelParam const* mparam); }; /*! * \brief Registry entry for tree updater. */ struct TreeUpdaterReg - : public dmlc::FunctionRegEntryBase > { -}; + : public dmlc::FunctionRegEntryBase< + TreeUpdaterReg, + std::function> {}; /*! * \brief Macro to register tree updater. 
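The RegTree changes above store one weight per target for every leaf in a flat
`leaf_values_` buffer indexed by `node_id * leaf_size_`, and prediction sums the
per-leaf vectors across trees (see `PredictVectorValue` in cpu_predictor.cc further
below). The following is a minimal NumPy sketch of that storage layout and of the
accumulation step, for illustration only; the helper names (`set_leaf`,
`vector_leaf_value`, `predict_one`) are hypothetical and not part of the patch.

import numpy as np

n_nodes, leaf_size = 7, 2          # leaf_size equals the number of targets
# Flat per-node leaf storage, mirroring RegTree::leaf_values_: the weights of
# node `nid` live at [nid * leaf_size, (nid + 1) * leaf_size).
leaf_values = np.zeros(n_nodes * leaf_size)


def set_leaf(nid, weights):
    # analogue of the vector-leaf RegTree::SetLeaf overload
    leaf_values[nid * leaf_size:(nid + 1) * leaf_size] = weights


def vector_leaf_value(nid):
    # analogue of RegTree::VectorLeafValue
    return leaf_values[nid * leaf_size:(nid + 1) * leaf_size]


def predict_one(leaf_ids_per_tree, base_score=0.5):
    # one traversal per tree; the whole leaf vector is accumulated instead of
    # a scalar, cf. CPUPredictor::PredictVectorValue
    out = np.full(leaf_size, base_score)
    for nid in leaf_ids_per_tree:
        out += vector_leaf_value(nid)
    return out


set_leaf(1, [0.3, -0.1])
set_leaf(2, [0.2, 0.4])
print(predict_one([1, 2]))         # -> [1.  0.8]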
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 079e916c3260..48a1ad2e0014 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -507,17 +507,32 @@ def set_uint_info(self, field, data): def set_interface_info(self, field, data): """Set info type property into DMatrix.""" # If we are passed a dataframe, extract the series - if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'): + if lazy_isinstance(data, 'cupy.core.core', 'ndarray'): + interface = [data.__cuda_array_interface__] + device = data.device.id + elif lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'): if len(data.columns) != 1: raise ValueError( 'Expecting meta-info to contain a single column') data = data[data.columns[0]] - - interface = bytes(json.dumps([data.__cuda_array_interface__], - indent=2), 'utf-8') - _check_call(_LIB.XGDMatrixSetInfoFromInterface(self.handle, - c_str(field), - interface)) + interface = [data.__cuda_array_interface__] + import cupy # pylint: disable=import-error + device = cupy.cuda.runtime.pointerGetAttributes( + interface[0]['data'][0]).device + elif lazy_isinstance(data, 'cudf.core.series', 'Series'): + interface = [data.__cuda_array_interface__] + import cupy # pylint: disable=import-error + device = cupy.cuda.runtime.pointerGetAttributes( + interface[0]['data'][0]).device + else: + interface = [data.__array_interface__] + device = -1 + interface = bytes(json.dumps(interface), 'utf-8') + _check_call(_LIB.XGDMatrixSetInfoFromInterface( + self.handle, + c_str(field), + interface, + ctypes.c_int(device))) def save_binary(self, fname, silent=True): """Save DMatrix to an XGBoost buffer. Saved binary can be later loaded @@ -544,6 +559,10 @@ def set_label(self, label): """ if _has_cuda_array_interface(label): self.set_interface_info('label', label) + elif (hasattr(label, '__array_interface__') and + len(label.shape) == 2 and + label.shape[1] > 1): + self.set_interface_info('label', label) else: self.set_float_info('label', label) diff --git a/python-package/xgboost/experimental.py b/python-package/xgboost/experimental.py new file mode 100644 index 000000000000..73c47dfc74fd --- /dev/null +++ b/python-package/xgboost/experimental.py @@ -0,0 +1,21 @@ +'''Experimental features in XGBoost. Code in this module are not stable and +may be changed without notice in the future. + +''' +from .sklearn import xgboost_model_doc, XGBModel, XGBRegressorBase + + +@xgboost_model_doc( + 'scikit-learn API for XGBoost multi-target regression.', + ['estimators', 'model', 'objective']) +class XGBMultiRegressor(XGBModel, XGBRegressorBase): + # pylint: disable=missing-docstring + def __init__(self, + objective="reg:squarederror", + output_type='single', + tree_method='exact', + **kwargs): + super().__init__(objective=objective, + tree_method=tree_method, + output_type=output_type, + **kwargs) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index ef22f4309ed7..e58df11ced0e 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -118,9 +118,12 @@ def inner(preds, dmatrix): [2, 3, 4]], where each inner list is a group of indices of features that are allowed to interact with each other. See tutorial for more information - importance_type: string, default "gain" + importance_type : str, default "gain" The feature importance type for the feature_importances\\_ property: either "gain", "weight", "cover", "total_gain" or "total_cover". 
+ tree_type : str, default 'single' + Either "single" or "multi". Used for different type of + multi-target regression. \\*\\*kwargs : dict, optional Keyword arguments for XGBoost Booster object. Full documentation of @@ -210,7 +213,7 @@ def __init__(self, max_depth=None, learning_rate=None, n_estimators=100, missing=np.nan, num_parallel_tree=None, monotone_constraints=None, interaction_constraints=None, importance_type="gain", gpu_id=None, - validate_parameters=None, **kwargs): + validate_parameters=None, output_type=None, **kwargs): if not SKLEARN_INSTALLED: raise XGBoostError( 'sklearn needs to be installed in order to use this module') @@ -242,6 +245,7 @@ def __init__(self, max_depth=None, learning_rate=None, n_estimators=100, self.interaction_constraints = interaction_constraints self.importance_type = importance_type self.gpu_id = gpu_id + self.output_type = output_type self.validate_parameters = validate_parameters def _more_tags(self): @@ -1023,7 +1027,7 @@ def get_num_boosting_rounds(self): @xgboost_model_doc( - "Implementation of the scikit-learn API for XGBoost regression.", + "scikit-learn API for XGBoost regression.", ['estimators', 'model', 'objective']) class XGBRegressor(XGBModel, XGBRegressorBase): # pylint: disable=missing-docstring diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 8431e424384a..de80ca4d4789 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -218,12 +218,13 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, } XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle, - char const* field, - char const* interface_c_str) { + char const *field, + char const *interface_c_str, + int32_t device) { API_BEGIN(); CHECK_HANDLE(); static_cast*>(handle) - ->get()->Info().SetInfo(field, interface_c_str); + ->get()->Info().SetInfo(field, interface_c_str, device); API_END(); } diff --git a/src/common/observer.h b/src/common/observer.h index 1af16d45dbd4..40660e85a30e 100644 --- a/src/common/observer.h +++ b/src/common/observer.h @@ -70,10 +70,10 @@ class TrainingObserver { OBSERVER_PRINT << "Procedure: " << name << OBSERVER_ENDL; for (size_t i = 0; i < h_vec.size(); ++i) { - OBSERVER_PRINT << h_vec[i] << ", "; - if (i % 8 == 0) { + if (i % 8 == 0 && i != 0) { OBSERVER_PRINT << OBSERVER_NEWLINE; } + OBSERVER_PRINT << h_vec[i] << ", "; if ((i + 1) == n) { break; } diff --git a/src/common/quantile.h b/src/common/quantile.h index c0079ff8ebc8..f4f7c4cd7a7f 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -564,7 +564,8 @@ class QuantileSketchTemplate { // check invariant size_t n = (1ULL << nlevel); CHECK(n * limit_size >= maxn) << "invalid init parameter"; - CHECK(nlevel <= std::max(static_cast(1), static_cast(limit_size * eps))) + CHECK(nlevel <= std::max(static_cast(1), + static_cast(limit_size * eps))) << "invalid init parameter"; } diff --git a/src/data/array_interface.h b/src/data/array_interface.h index c8abb2d450da..c2f7d9feba8f 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -87,7 +87,7 @@ struct ArrayInterfaceErrors { } } - static std::string UnSupportedType(const char (&typestr)[3]) { + static std::string UnSupportedType(char const typestr[3]) { return TypeStr(typestr[1]) + " is not supported."; } }; diff --git a/src/data/data.cc b/src/data/data.cc index f24753e31e6b..784bbd30d82c 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -11,8 +11,10 @@ #include "xgboost/host_device_vector.h" #include "xgboost/logging.h" #include "xgboost/version_config.h" +#include "xgboost/json.h" #include 
"sparse_page_writer.h" #include "simple_dmatrix.h" +#include "array_interface.h" #include "../common/io.h" #include "../common/math.h" @@ -434,11 +436,40 @@ void MetaInfo::Validate(int32_t device) const { } } + +void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str, int32_t device) { + if (device == GenericParameter::kCpuId) { + Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()}); + auto const& arr = get(j_interface); + CHECK_EQ(arr.size(), 1) << "Columnar format array interface for CPU is not supported."; + auto obj = get(arr.at(0)); + ArrayInterface interface {obj}; + CHECK_EQ(c_key, std::string{"label"}) + << "Only labels is supported for setting meta info with array " + "interface."; + if (interface.num_cols != 1) { + LOG(WARNING) << "Found 2-d array label. Multi target support is at " + "experimental stage " + "and is not recommended for real world usage."; + } + + auto& h_labels = labels_.HostVector(); + h_labels.resize(interface.num_rows * interface.num_cols); +#pragma omp parallel for schedule(static) + for (omp_ulong i = 0; i < h_labels.size(); ++i) { + h_labels[i] = interface.GetElement(i); + } + labels_rows = interface.num_rows; + labels_cols = interface.num_cols; + } else { #if !defined(XGBOOST_USE_CUDA) -void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { common::AssertGPUSupport(); -} + LOG(FATAL) << "XGBoost version not compiled with GPU support."; +#else + this->SetInfoDevice(c_key, interface_str); #endif // !defined(XGBOOST_USE_CUDA) + } +} DMatrix* DMatrix::Load(const std::string& uri, bool silent, diff --git a/src/data/data.cu b/src/data/data.cu index fb57f4751545..2c41e2a9e955 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -58,7 +58,7 @@ void CopyGroupInfoImpl(ArrayInterface column, std::vector* out) { std::partial_sum(out->begin(), out->end(), out->begin()); } -void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { +void MetaInfo::SetInfoDevice(const char * c_key, std::string const& interface_str) { Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()}); auto const& j_arr = get(j_interface); CHECK_EQ(j_arr.size(), 1) diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index e554f8a559e0..8941028394b2 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -18,7 +18,7 @@ #include "xgboost/predictor.h" #include "xgboost/linear_updater.h" #include "xgboost/logging.h" -#include "xgboost/learner.h" +#include "xgboost/model.h" #include "gblinear_model.h" #include "../common/timer.h" diff --git a/src/gbm/gblinear_model.h b/src/gbm/gblinear_model.h index f2d0d9a868d3..b6d4a131805f 100644 --- a/src/gbm/gblinear_model.h +++ b/src/gbm/gblinear_model.h @@ -4,7 +4,6 @@ #pragma once #include #include -#include #include #include diff --git a/src/gbm/gbm.cc b/src/gbm/gbm.cc index 87a6ded29042..3c6936f37ce1 100644 --- a/src/gbm/gbm.cc +++ b/src/gbm/gbm.cc @@ -9,7 +9,7 @@ #include #include "xgboost/gbm.h" -#include "xgboost/learner.h" +#include "xgboost/model.h" #include "xgboost/generic_parameters.h" namespace dmlc { diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 61a3021cbefd..cb3d0c23f5ef 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -121,6 +121,7 @@ void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) { return; } // tparam_ is set before calling this function. 
+ CHECK(tparam_.GetInitialised()); if (tparam_.tree_method != TreeMethod::kAuto) { return; } @@ -162,13 +163,11 @@ void GBTree::ConfigureUpdaters() { case TreeMethod::kApprox: tparam_.updater_seq = "grow_histmaker,prune"; break; - case TreeMethod::kExact: - tparam_.updater_seq = "grow_colmaker,prune"; + case TreeMethod::kExact: { + tparam_.updater_seq = "grow_colmaker"; break; + } case TreeMethod::kHist: - LOG(INFO) << - "Tree method is selected to be 'hist', which uses a " - "single updater grow_quantile_histmaker."; tparam_.updater_seq = "grow_quantile_histmaker"; break; case TreeMethod::kGPUHist: { @@ -201,6 +200,8 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector tmp(in_gpair->Size() / ngroup, GradientPair(), in_gpair->DeviceIdx()); + CHECK(model_.learner_model_param->output_type == OutputType::kSingle) + << "Using one tree per-target should not choose multi-target tree."; const auto& gpair_h = in_gpair->ConstHostVector(); auto nsize = static_cast(tmp.Size()); for (int gid = 0; gid < ngroup; ++gid) { @@ -252,7 +253,8 @@ void GBTree::InitUpdater(Args const& cfg) { // create new updaters for (const std::string& pstr : ups) { - std::unique_ptr up(TreeUpdater::Create(pstr.c_str(), generic_param_)); + std::unique_ptr up(TreeUpdater::Create(pstr.c_str(), generic_param_, + model_.learner_model_param)); up->Configure(cfg); updaters_.push_back(std::move(up)); } @@ -273,7 +275,12 @@ void GBTree::BoostNewTrees(HostDeviceVector* gpair, << "Set `process_type` to `update` if you want to update existing " "trees."; // create new tree - std::unique_ptr ptr(new RegTree()); + std::unique_ptr ptr; + if (model_.learner_model_param->output_type == OutputType::kSingle) { + ptr.reset(new RegTree(1, RegTree::kSingle)); + } else { + ptr.reset(new RegTree(model_.learner_model_param->num_targets, RegTree::kMulti)); + } ptr->param.UpdateAllowUnknown(this->cfg_); new_trees.push_back(ptr.get()); ret->push_back(std::move(ptr)); @@ -348,7 +355,9 @@ void GBTree::LoadConfig(Json const& in) { auto const& j_updaters = get(in["updater"]); updaters_.clear(); for (auto const& kv : j_updaters) { - std::unique_ptr up(TreeUpdater::Create(kv.first, generic_param_)); + CHECK(model_.learner_model_param); + std::unique_ptr up(TreeUpdater::Create(kv.first, generic_param_, + model_.learner_model_param)); up->LoadConfig(kv.second); updaters_.push_back(std::move(up)); } @@ -681,7 +690,7 @@ class Dart : public GBTree { bool drop = std::binary_search(idx_drop_.begin(), idx_drop_.end(), i); if (!drop) { int tid = model_.trees[i]->GetLeafIndex(*p_feats); - psum += weight_drop_[i] * (*model_.trees[i])[tid].LeafValue(); + psum += weight_drop_[i] * (*model_.trees[i]).LeafValue(tid); } } } diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 534c3ad5469a..1a60a4a1df49 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -71,6 +71,8 @@ struct GBTreeTrainParam : public XGBoostParameter { PredictorType predictor; // tree construction method TreeMethod tree_method; + /*! 
\brief size of leaf vector needed in tree */ + RegTree::TreeKind tree_type; // declare parameters DMLC_DECLARE_PARAMETER(GBTreeTrainParam) { DMLC_DECLARE_FIELD(num_parallel_tree) @@ -79,7 +81,7 @@ struct GBTreeTrainParam : public XGBoostParameter { .describe("Number of parallel trees constructed during each iteration."\ " This option is used to support boosted random forest."); DMLC_DECLARE_FIELD(updater_seq) - .set_default("grow_colmaker,prune") + .set_default("grow_colmaker") .describe("Tree updater sequence."); DMLC_DECLARE_FIELD(process_type) .set_default(TreeProcessType::kDefault) @@ -102,6 +104,11 @@ struct GBTreeTrainParam : public XGBoostParameter { .add_enum("hist", TreeMethod::kHist) .add_enum("gpu_hist", TreeMethod::kGPUHist) .describe("Choice of tree construction method."); + DMLC_DECLARE_FIELD(tree_type) + .add_enum("single", RegTree::TreeKind::kSingle) + .add_enum("multi", RegTree::TreeKind::kMulti) + .set_default(RegTree::TreeKind::kSingle) + .describe("Type of tree."); } }; diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h index 7ac7d8f470a2..a2b0d4a0a99c 100644 --- a/src/gbm/gbtree_model.h +++ b/src/gbm/gbtree_model.h @@ -7,16 +7,18 @@ #include #include -#include -#include -#include -#include #include #include #include #include +#include "xgboost/tree_model.h" +#include "xgboost/parameter.h" +#include "xgboost/model.h" + +DECLARE_FIELD_ENUM_CLASS(xgboost::RegTree::TreeKind); + namespace xgboost { class Json; @@ -55,7 +57,7 @@ struct GBTreeModelParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(num_trees) .set_lower_bound(0) .set_default(0) - .describe("Number of features used for training and prediction."); + .describe("Number of trees used for training and prediction."); DMLC_DECLARE_FIELD(size_leaf_vector) .set_lower_bound(0) .set_default(0) diff --git a/src/learner.cc b/src/learner.cc index 34649480c5ce..26ff551ac0e5 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -56,6 +56,7 @@ enum class DataSplitMode : int { } // namespace xgboost DECLARE_FIELD_ENUM_CLASS(xgboost::DataSplitMode); +DECLARE_FIELD_ENUM_CLASS(xgboost::OutputType); namespace xgboost { // implementation of base learner. @@ -85,8 +86,11 @@ struct LearnerModelParamLegacy : public dmlc::Parameter /*! \brief the version of XGBoost. */ uint32_t major_version; uint32_t minor_version; + /*! \brief Used for multi target trees. */ + uint32_t num_targets; + OutputType output_type; /*! \brief reserved field */ - int reserved[27]; + int reserved[25]; /*! \brief constructor */ LearnerModelParamLegacy() { std::memset(this, 0, sizeof(LearnerModelParamLegacy)); @@ -102,6 +106,12 @@ struct LearnerModelParamLegacy : public dmlc::Parameter obj["base_score"] = std::to_string(base_score); obj["num_feature"] = std::to_string(num_feature); obj["num_class"] = std::to_string(num_class); + obj["num_targets"] = std::to_string(num_targets); + if (output_type == OutputType::kSingle) { + obj["output_type"] = String("single"); + } else { + obj["output_type"] = String("multi"); + } return Json(std::move(obj)); } void FromJson(Json const& obj) { @@ -110,6 +120,14 @@ struct LearnerModelParamLegacy : public dmlc::Parameter m["base_score"] = get(j_param.at("base_score")); m["num_feature"] = get(j_param.at("num_feature")); m["num_class"] = get(j_param.at("num_class")); + + if (j_param.find("num_targets") == j_param.cend()) { + LOG(WARNING) << "Using old experimental JSON model. 
Please consider " + "saving the model again in current version of XGBoost."; + } else { + m["num_targets"] = get(j_param.at("num_targets")); + m["output_type"] = get(j_param.at("output_type")); + } this->Init(m); } // declare parameters @@ -122,6 +140,11 @@ struct LearnerModelParamLegacy : public dmlc::Parameter .describe( "Number of features in training data," " this parameter will be automatically detected by learner."); + DMLC_DECLARE_FIELD(num_targets).set_default(0); + DMLC_DECLARE_FIELD(output_type) + .set_default(OutputType::kSingle) + .add_enum("single", OutputType::kSingle) + .add_enum("multi", OutputType::kMulti); DMLC_DECLARE_FIELD(num_class).set_default(0).set_lower_bound(0).describe( "Number of class option for multi-class classifier. " " By default equals 0 and corresponds to binary classifier."); @@ -131,10 +154,18 @@ struct LearnerModelParamLegacy : public dmlc::Parameter LearnerModelParam::LearnerModelParam( LearnerModelParamLegacy const &user_param, float base_margin) : base_score{base_margin}, num_feature{user_param.num_feature}, - num_output_group{user_param.num_class == 0 - ? 1 - : static_cast(user_param.num_class)} -{} + num_targets{user_param.num_targets}, output_type{user_param.output_type} +{ + if (user_param.output_type == OutputType::kSingle) { + CHECK(user_param.num_class == 0 || user_param.num_targets == 0); + num_output_group = std::max(static_cast(user_param.num_class), + user_param.num_targets); + num_targets = 1; + } else { + num_targets = std::max(num_targets, static_cast(user_param.num_class)); + } + num_output_group = std::max(num_output_group, 1u); +} struct LearnerTrainParam : public XGBoostParameter { // data split mode, can be row, col, or none. @@ -980,8 +1011,9 @@ class LearnerImpl : public LearnerIO { auto &out = output_predictions_.Cache(m, generic_parameters_.gpu_id).predictions; out.Resize(predt.predictions.Size()); out.Copy(predt.predictions); - + TrainingObserver::Instance().Observe(out, "Before Transform"); obj_->EvalTransform(&out); + TrainingObserver::Instance().Observe(out, "Eval Transform"); for (auto& ev : metrics_) { os << '\t' << data_names[i] << '-' << ev->Name() << ':' << ev->Eval(out, m->Info(), tparam_.dsplit == DataSplitMode::kRow); @@ -1075,6 +1107,16 @@ class LearnerImpl : public LearnerIO { CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_) << "Number of columns does not match number of features in booster."; } + + if (learner_model_param_.output_type == OutputType::kSingle) { + CHECK(p_fmat->Info().labels_cols == 1 || + p_fmat->Info().labels_cols == learner_model_param_.num_output_group); + } else { + CHECK(p_fmat->Info().labels_cols == learner_model_param_.num_targets || + p_fmat->Info().labels_cols == learner_model_param_.num_output_group) + << "p_fmat->Info().labels_cols: " << p_fmat->Info().labels_cols << ", " + << "learner_model_param_.num_targets: " << learner_model_param_.num_targets; + } } private: diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index 377a05010ea3..474aa2e7a4bf 100644 --- a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -41,25 +41,23 @@ class MultiClassMetricsReduction { const HostDeviceVector& weights, const HostDeviceVector& labels, const HostDeviceVector& preds, + size_t ndata, const size_t n_class) const { - size_t ndata = labels.Size(); - - const auto& h_labels = labels.HostVector(); - const auto& h_weights = weights.HostVector(); - const auto& h_preds = preds.HostVector(); + const auto h_labels = labels.HostSpan(); + const 
auto h_weights = weights.HostSpan(); + const auto h_preds = preds.HostSpan(); bst_float residue_sum = 0; bst_float weights_sum = 0; int label_error = 0; bool const is_null_weight = weights.Size() == 0; - #pragma omp parallel for reduction(+: residue_sum, weights_sum) schedule(static) for (omp_ulong idx = 0; idx < ndata; ++idx) { bst_float weight = is_null_weight ? 1.0f : h_weights[idx]; auto label = static_cast(h_labels[idx]); if (label >= 0 && label < static_cast(n_class)) { residue_sum += EvalRowPolicy::EvalRow( - label, h_preds.data() + idx * n_class, n_class) * weight; + h_labels, h_preds, idx, n_class) * weight; weights_sum += weight; } else { label_error = label; @@ -77,9 +75,8 @@ class MultiClassMetricsReduction { const HostDeviceVector& weights, const HostDeviceVector& labels, const HostDeviceVector& preds, + size_t const n_data, const size_t n_class) { - size_t n_data = labels.Size(); - thrust::counting_iterator begin(0); thrust::counting_iterator end = begin + n_data; @@ -101,7 +98,7 @@ class MultiClassMetricsReduction { auto label = static_cast(s_labels[idx]); if (label >= 0 && label < static_cast(n_class)) { residue = EvalRowPolicy::EvalRow( - label, &s_preds[idx * n_class], n_class) * weight; + s_labels, s_preds, idx, n_class) * weight; } else { s_label_error[0] = label; } @@ -120,13 +117,14 @@ class MultiClassMetricsReduction { const GenericParameter &tparam, int device, size_t n_class, - const HostDeviceVector& weights, - const HostDeviceVector& labels, + MetaInfo const& info, const HostDeviceVector& preds) { PackedReduceResult result; + auto const& labels = info.labels_; + auto const& weights = info.weights_; if (device < 0) { - result = CpuReduceMetrics(weights, labels, preds, n_class); + result = CpuReduceMetrics(weights, labels, preds, info.num_row_, n_class); } #if defined(XGBOOST_USE_CUDA) else { // NOLINT @@ -136,7 +134,7 @@ class MultiClassMetricsReduction { weights.SetDevice(device_); dh::safe_cuda(cudaSetDevice(device_)); - result = DeviceReduceMetrics(weights, labels, preds, n_class); + result = DeviceReduceMetrics(weights, labels, preds, info.num_row_, n_class); } #endif // defined(XGBOOST_USE_CUDA) return result; @@ -161,13 +159,13 @@ struct EvalMClassBase : public Metric { CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; CHECK(preds.Size() % info.labels_.Size() == 0) << "label and prediction size not match"; - const size_t nclass = preds.Size() / info.labels_.Size(); + const size_t nclass = preds.Size() / info.num_row_; CHECK_GE(nclass, 1U) << "mlogloss and merror are only used for multi-class classification," << " use logloss for binary classification"; int device = tparam_->gpu_id; - auto result = reducer_.Reduce(*tparam_, device, nclass, info.weights_, info.labels_, preds); + auto result = reducer_.Reduce(*tparam_, device, nclass, info, preds); double dat[2] { result.Residue(), result.Weights() }; if (distributed) { @@ -178,12 +176,14 @@ struct EvalMClassBase : public Metric { /*! * \brief to be implemented by subclass, * get evaluation result from one row - * \param label label of current instance - * \param pred prediction value of current instance + * \param s_labels label of current instance + * \param s_predt prediction value of current instance + * \param idx index of current instance * \param nclass number of class in the prediction */ - XGBOOST_DEVICE static bst_float EvalRow(int label, - const bst_float *pred, + XGBOOST_DEVICE static bst_float EvalRow(common::Span s_label, + common::Span s_predt, + size_t idx, size_t nclass); /*! 
* \brief to be overridden by subclass, final transformation @@ -205,10 +205,15 @@ struct EvalMatchError : public EvalMClassBase { const char* Name() const override { return "merror"; } - XGBOOST_DEVICE static bst_float EvalRow(int label, - const bst_float *pred, + XGBOOST_DEVICE static bst_float EvalRow(common::Span s_label, + common::Span s_predt, + size_t idx, size_t nclass) { - return common::FindMaxIndex(pred, pred + nclass) != pred + static_cast(label); + auto pred = s_predt.subspan(idx * nclass, nclass); + auto label = static_cast(s_label[idx]); + return + common::FindMaxIndex(pred.begin(), pred.begin() + nclass) != + pred.begin() + static_cast(label); } }; @@ -217,19 +222,47 @@ struct EvalMultiLogLoss : public EvalMClassBase { const char* Name() const override { return "mlogloss"; } - XGBOOST_DEVICE static bst_float EvalRow(int label, - const bst_float *pred, + XGBOOST_DEVICE static bst_float EvalRow(common::Span s_label, + common::Span s_predt, + size_t idx, size_t nclass) { + auto pred = s_predt.subspan(idx * nclass, nclass); + auto label = static_cast(s_label[idx]); const bst_float eps = 1e-16f; - auto k = static_cast(label); - if (pred[k] > eps) { - return -std::log(pred[k]); + if (pred[label] > eps) { + return -std::log(pred[label]); } else { return -std::log(eps); } } }; +struct EvalMultiLogLossOneHot : public EvalMClassBase { + const char* Name() const override { + return "mtlogloss"; + } + XGBOOST_DEVICE static bst_float EvalRow(common::Span s_label, + common::Span s_predt, + size_t idx, + size_t nclass) { + auto predt = s_predt.subspan(idx * nclass, nclass); + auto label = s_label.subspan(idx * nclass, nclass); + size_t k = 0; + for (; k < nclass; ++k) { + if (label[k] == 1) { + break; + } + } + float ret { 0 }; + if (predt[k] > kRtEps) { + ret = -std::log(predt[k]); + } else { + ret = -std::log(kRtEps); + } + return ret; + } +}; + XGBOOST_REGISTER_METRIC(MatchError, "merror") .describe("Multiclass classification error.") .set_body([](const char* param) { return new EvalMatchError(); }); @@ -237,5 +270,9 @@ XGBOOST_REGISTER_METRIC(MatchError, "merror") XGBOOST_REGISTER_METRIC(MultiLogLoss, "mlogloss") .describe("Multiclass negative loglikelihood.") .set_body([](const char* param) { return new EvalMultiLogLoss(); }); + +XGBOOST_REGISTER_METRIC(MultiLogLossOneHot, "mtlogloss") +.describe("Multiclass negative loglikelihood.") +.set_body([](const char* param) { return new EvalMultiLogLossOneHot(); }); } // namespace metric } // namespace xgboost diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 1c84f4947514..63d15800da31 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -25,21 +25,28 @@ namespace predictor { DMLC_REGISTRY_FILE_TAG(cpu_predictor); -bst_float PredValue(const SparsePage::Inst &inst, - const std::vector> &trees, - const std::vector &tree_info, int bst_group, - RegTree::FVec *p_feats, unsigned tree_begin, - unsigned tree_end) { - bst_float psum = 0.0f; - p_feats->Fill(inst); - for (size_t i = tree_begin; i < tree_end; ++i) { - if (tree_info[i] == bst_group) { - int tid = trees[i]->GetLeafIndex(*p_feats); - psum += (*trees[i])[tid].LeafValue(); +int GetNext(std::unique_ptr const& tree, + int pid, bst_float fvalue, bool is_unknown) { + bst_float split_value = (*tree)[pid].SplitCond(); + if (is_unknown) { + return (*tree)[pid].DefaultChild(); + } else { + if (fvalue < split_value) { + return (*tree)[pid].LeftChild(); + } else { + return (*tree)[pid].RightChild(); } } - p_feats->Drop(inst); - return 
psum; +} + +int GetLeafIndex(std::unique_ptr const& tree, + const RegTree::FVec& feat) { + bst_node_t nid = 0; + while (!(*tree)[nid].IsLeaf()) { + unsigned split_index = (*tree)[nid].SplitIndex(); + nid = GetNext(tree, nid, feat.GetFvalue(split_index), feat.IsMissing(split_index)); + } + return nid; } template @@ -101,6 +108,23 @@ class AdapterView { bst_row_t const static base_rowid = 0; // NOLINT }; +bst_float PredValue(const SparsePage::Inst &inst, + const std::vector> &trees, + const std::vector &tree_info, int bst_group, + RegTree::FVec *p_feats, unsigned tree_begin, + unsigned tree_end) { + bst_float psum = 0.0f; + p_feats->Fill(inst); + for (size_t i = tree_begin; i < tree_end; ++i) { + if (tree_info[i] == bst_group) { + int tid = trees[i]->GetLeafIndex(*p_feats); + psum += (*trees[i]).LeafValue(tid); + } + } + p_feats->Drop(inst); + return psum; +} + template void PredictBatchKernel(DataView batch, std::vector *out_preds, gbm::GBTreeModel const &model, int32_t tree_begin, @@ -163,6 +187,53 @@ class CPUPredictor : public Predictor { } } + void + PredictVectorValue(const SparsePage::Inst &inst, + const std::vector> &trees, + RegTree::FVec *p_feats, + unsigned tree_begin, unsigned tree_end, common::Span out) { + CHECK_EQ(out.size(), trees[0]->LeafSize()); + p_feats->Fill(inst); + for (size_t i = tree_begin; i < tree_end; ++i) { + // no group id for vector leaf. + auto const& tree = trees[i]; + int tid = GetLeafIndex(tree, *p_feats); + auto vl = tree->VectorLeafValue(tid); + for (size_t j = 0; j < out.size(); ++j) { + out[j] += vl[j]; + } + } + p_feats->Drop(inst); + } + + void PredictVectorInternal(SparsePage const& page, gbm::GBTreeModel const &model, + std::vector *out_preds, + uint32_t tree_begin, uint32_t tree_end) { + std::lock_guard guard(lock_); + const int threads = omp_get_max_threads(); + InitThreadTemp(threads, model.learner_model_param->num_feature, &this->thread_temp_); + page.data.HostVector(); + page.offset.HostVector(); + dmlc::OMPException omp_handler; + size_t targets = model.learner_model_param->num_targets; +#pragma omp parallel for + for (omp_ulong i = 0; i < page.Size(); ++i) { + omp_handler.Run( + [this, &page, &model, tree_begin, tree_end, out_preds, + targets](omp_ulong i) { + auto inst = page[i]; + const int tid = omp_get_thread_num(); + RegTree::FVec &feats = thread_temp_[tid]; + size_t offset = targets * i; + auto out = common::Span(*out_preds).subspan(offset, targets); + this->PredictVectorValue(inst, model.trees, &feats, tree_begin, + tree_end, out); + }, + i); + } + omp_handler.Rethrow(); + } + void PredictDMatrix(DMatrix *p_fmat, std::vector *out_preds, gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) { @@ -182,29 +253,19 @@ class CPUPredictor : public Predictor { HostDeviceVector* out_preds, const gbm::GBTreeModel& model) const { CHECK_NE(model.learner_model_param->num_output_group, 0); - size_t n = model.learner_model_param->num_output_group * info.num_row_; + size_t n = std::max(model.learner_model_param->num_targets, + model.learner_model_param->num_output_group) * + info.num_row_; const auto& base_margin = info.base_margin_.HostVector(); - out_preds->Resize(n); std::vector& out_preds_h = out_preds->HostVector(); - if (base_margin.size() == n) { - CHECK_EQ(out_preds->Size(), n); - std::copy(base_margin.begin(), base_margin.end(), out_preds_h.begin()); + // size_t const out_size = info.labels_cols * n; + out_preds_h.resize(n); + + if (base_margin.size() != 0) { + CHECK_EQ(base_margin.size(), out_preds_h.size()) + << "Size of base 
margin must equal to length of prediction."; + std::copy(base_margin.cbegin(), base_margin.cend(), out_preds_h.begin()); } else { - if (!base_margin.empty()) { - std::ostringstream oss; - oss << "Ignoring the base margin, since it has incorrect length. " - << "The base margin must be an array of length "; - if (model.learner_model_param->num_output_group > 1) { - oss << "[num_class] * [number of data points], i.e. " - << model.learner_model_param->num_output_group << " * " << info.num_row_ - << " = " << n << ". "; - } else { - oss << "[number of data points], i.e. " << info.num_row_ << ". "; - } - oss << "Instead, all data points will use " - << "base_score = " << model.learner_model_param->base_score; - LOG(WARNING) << oss.str(); - } std::fill(out_preds_h.begin(), out_preds_h.end(), model.learner_model_param->base_score); } @@ -251,9 +312,18 @@ class CPUPredictor : public Predictor { CHECK_LE(beg_version, end_version); if (beg_version < end_version) { - this->PredictDMatrix(dmat, &out_preds->HostVector(), model, - beg_version * output_groups, - end_version * output_groups); + if (model.trees.front()->Kind() == RegTree::kMulti) { + CHECK_EQ(output_groups, 1); + for (auto const& page : dmat->GetBatches()) { + this->PredictVectorInternal(page, model, &out_preds->HostVector(), + beg_version * output_groups, + end_version * output_groups); + } + } else { + this->PredictDMatrix(dmat, &out_preds->HostVector(), model, + beg_version * output_groups, + end_version * output_groups); + } } // delta means {size of forest} * {number of newly accumulated layers} @@ -262,7 +332,8 @@ class CPUPredictor : public Predictor { predts->Update(delta); CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ || - out_preds->Size() == dmat->Info().num_row_); + out_preds->Size() == + model.learner_model_param->num_targets * dmat->Info().num_row_); } template @@ -315,6 +386,11 @@ class CPUPredictor : public Predictor { } out_preds->resize(model.learner_model_param->num_output_group * (model.param.size_leaf_vector + 1)); + if (model.trees.size() == 0) { + return; + } + out_preds->resize(model.learner_model_param->num_output_group * + (model.trees.front()->LeafSize() + 1)); // loop over output groups for (uint32_t gid = 0; gid < model.learner_model_param->num_output_group; ++gid) { (*out_preds)[gid] = PredValue(inst, model.trees, model.tree_info, gid, diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 9a498136d41a..0498ea7f06f1 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -175,7 +175,7 @@ __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree, } } } - return n.LeafValue(); + return n.SinlgeLeafValue(); } template diff --git a/src/tree/param.h b/src/tree/param.h index 280f06066e44..1ff6536799b0 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -272,7 +272,7 @@ XGBOOST_DEVICE inline T CalcGainGivenWeight(const TrainingParams &p, // calculate the cost of loss function template -XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess) { +XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T const& sum_grad, T const& sum_hess) { if (sum_hess < p.min_child_weight) { return T(0.0); } @@ -295,8 +295,8 @@ XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess } template -XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, StatT stat) { + typename StatT, typename T = typename std::remove_reference< decltype(StatT().GetHess()) >::type> +XGBOOST_DEVICE inline T 
CalcGain(const TrainingParams &p, StatT const& stat) { return CalcGain(p, stat.GetGrad(), stat.GetHess()); } @@ -453,6 +453,7 @@ struct SplitEntryContainer { return false; } } + /*! * \brief update the split entry, replace it if e is better * \param new_loss_chg loss reduction of new candidate diff --git a/src/tree/split_evaluator.cc b/src/tree/split_evaluator.cc index be166156b004..93979c79796e 100644 --- a/src/tree/split_evaluator.cc +++ b/src/tree/split_evaluator.cc @@ -245,6 +245,7 @@ class MonotonicConstraint final : public SplitEvaluator { lower_[leftid] = lower_.at(nodeid); lower_[rightid] = lower_.at(nodeid); + // when value is lesser than split_cond, data is assigned to left. if (constraint < 0) { lower_[leftid] = mid; upper_[rightid] = mid; diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 8f45621ca15e..99e68f41d88a 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -188,7 +188,7 @@ class TextGenerator : public TreeGenerator { kLeafTemplate, {{"{tabs}", SuperT::Tabs(depth)}, {"{nid}", std::to_string(nid)}, - {"{leaf}", SuperT::ToStr(tree[nid].LeafValue())}, + {"{leaf}", SuperT::ToStr(tree.LeafValue(nid))}, {"{stats}", with_stats_ ? SuperT::Match(kStatTemplate, {{"{cover}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}}); @@ -313,7 +313,7 @@ class JsonGenerator : public TreeGenerator { std::string result = SuperT::Match( kLeafTemplate, {{"{nid}", std::to_string(nid)}, - {"{leaf}", SuperT::ToStr(tree[nid].LeafValue())}, + {"{leaf}", SuperT::ToStr(tree.LeafValue(nid))}, {"{stat}", with_stats_ ? SuperT::Match( kStatTemplate, {{"{sum_hess}", @@ -569,7 +569,7 @@ class GraphvizGenerator : public TreeGenerator { " {nid} [ label=\"leaf={leaf-value}\" {params}]\n"; auto result = SuperT::Match(kLeafTemplate, { {"{nid}", std::to_string(nid)}, - {"{leaf-value}", ToStr(tree[nid].LeafValue())}, + {"{leaf-value}", ToStr(tree.LeafValue(nid))}, {"{params}", param_.leaf_node_params}}); return result; }; @@ -613,6 +613,8 @@ constexpr bst_node_t RegTree::kRoot; std::string RegTree::DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const { + CHECK_EQ(Kind(), kSingle) + << "Dump model is not available for multi-target tree."; std::unique_ptr builder { TreeGenerator::Create(format, fmap, with_stats) }; @@ -623,6 +625,7 @@ std::string RegTree::DumpModel(const FeatureMap& fmap, } bool RegTree::Equal(const RegTree& b) const { + CHECK_EQ(Kind(), kSingle); if (NumExtraNodes() != b.NumExtraNodes()) { return false; } @@ -663,6 +666,7 @@ bst_node_t RegTree::GetNumSplitNodes() const { } void RegTree::Load(dmlc::Stream* fi) { + CHECK_NE(Kind(), kMulti) << "Multi-target tree requires JSON serialization format."; CHECK_EQ(fi->Read(¶m, sizeof(TreeParam)), sizeof(TreeParam)); nodes_.resize(param.num_nodes); stats_.resize(param.num_nodes); @@ -681,6 +685,7 @@ void RegTree::Load(dmlc::Stream* fi) { CHECK_EQ(static_cast(deleted_nodes_.size()), param.num_deleted); } void RegTree::Save(dmlc::Stream* fo) const { + CHECK_NE(Kind(), kMulti) << "Model persistent for multi-target tree is not yet implemented."; CHECK_EQ(param.num_nodes, static_cast(nodes_.size())); CHECK_EQ(param.num_nodes, static_cast(stats_.size())); fo->Write(¶m, sizeof(TreeParam)); @@ -756,6 +761,7 @@ void RegTree::LoadModel(Json const& in) { } void RegTree::SaveModel(Json* p_out) const { + CHECK_NE(Kind(), kMulti) << "Model persistent for multi-target tree is not yet implemented."; auto& out = *p_out; CHECK_EQ(param.num_nodes, static_cast(nodes_.size())); CHECK_EQ(param.num_nodes, 
static_cast(stats_.size())); @@ -820,7 +826,7 @@ bst_float RegTree::FillNodeMeanValue(int nid) { bst_float result; auto& node = (*this)[nid]; if (node.IsLeaf()) { - result = node.LeafValue(); + result = LeafValue(nid); } else { result = this->FillNodeMeanValue(node.LeftChild()) * this->Stat(node.LeftChild()).sum_hess; result += this->FillNodeMeanValue(node.RightChild()) * this->Stat(node.RightChild()).sum_hess; @@ -832,6 +838,7 @@ bst_float RegTree::FillNodeMeanValue(int nid) { void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat, bst_float *out_contribs) const { + CHECK_EQ(Kind(), kSingle) << "Contribution is not available for mutli-target tree."; CHECK_GT(this->node_mean_values_.size(), 0U); // this follows the idea of http://blog.datadive.net/interpreting-random-forests/ unsigned split_index = 0; @@ -851,7 +858,7 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat, out_contribs[split_index] += new_value - node_value; node_value = new_value; } - bst_float leaf_value = (*this)[nid].LeafValue(); + bst_float leaf_value = this->LeafValue(nid); // update leaf feature weight out_contribs[split_index] += leaf_value - node_value; } @@ -947,6 +954,7 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi, bst_float parent_one_fraction, int parent_feature_index, int condition, unsigned condition_feature, bst_float condition_fraction) const { + CHECK_EQ(Kind(), kSingle) << "Tree shap is not available for mutli-target tree."; const auto node = (*this)[node_index]; // stop if we have no weight coming down to us @@ -968,7 +976,7 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi, const bst_float w = UnwoundPathSum(unique_path, unique_depth, i); const PathElement &el = unique_path[i]; phi[el.feature_index] += w * (el.one_fraction - el.zero_fraction) - * node.LeafValue() * condition_fraction; + * LeafValue(node_index) * condition_fraction; } // internal node diff --git a/src/tree/tree_updater.cc b/src/tree/tree_updater.cc index 24ff2ab92968..1b2e7e87eb8c 100644 --- a/src/tree/tree_updater.cc +++ b/src/tree/tree_updater.cc @@ -14,13 +14,16 @@ DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg); namespace xgboost { -TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const* tparam) { +TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const* tparam, + LearnerModelParam const* mparam) { auto *e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name); if (e == nullptr) { LOG(FATAL) << "Unknown tree updater " << name; } - auto p_updater = (e->body)(); - p_updater->tparam_ = tparam; + auto p_updater = (e->body)(tparam, mparam); + if (!p_updater->tparam_) { + p_updater->tparam_ = tparam; + } return p_updater; } diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index 951cfdb5ec27..f358fc6d738c 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -9,7 +9,9 @@ #include #include #include +#include +#include "xgboost/base.h" #include "xgboost/parameter.h" #include "xgboost/tree_updater.h" #include "xgboost/logging.h" @@ -17,6 +19,7 @@ #include "param.h" #include "constraints.h" #include "../common/random.h" +#include "../common/timer.h" #include "split_evaluator.h" namespace xgboost { @@ -49,577 +52,558 @@ struct ColMakerTrainParam : XGBoostParameter { DMLC_REGISTER_PARAMETER(ColMakerTrainParam); -/*! 
\brief column-wise update to construct a tree */ -class ColMaker: public TreeUpdater { - public: - void Configure(const Args& args) override { - param_.UpdateAllowUnknown(args); - colmaker_param_.UpdateAllowUnknown(args); - if (!spliteval_) { - spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator)); - } - spliteval_->Init(¶m_); +namespace { +struct ThreadEntry { + /*! \brief statistics of data */ + GradStats stats; + /*! \brief last feature value scanned */ + bst_float last_fvalue { 0 }; + /*! \brief current best solution */ + SplitEntry best; + // constructor + ThreadEntry() = default; +}; +struct NodeEntry { + /*! \brief statics for node entry */ + GradStats stats; + /*! \brief loss of this node, without split */ + bst_float root_gain { 0.0f }; + /*! \brief weight calculated related to current data */ + bst_float weight { 0.0f }; + /*! \brief current best solution */ + SplitEntry best; + friend std::ostream& operator<<(std::ostream& os, NodeEntry const& e) { + os << "stats: " << e.stats << ", " + << "root_gain: " << e.root_gain << ", " + << "weight: " << e.weight << ", " + << "best: " << e.best << std::endl; + return os; } + // constructor + NodeEntry() = default; +}; - void LoadConfig(Json const& in) override { - auto const& config = get(in); - FromJson(config.at("train_param"), &this->param_); - FromJson(config.at("colmaker_train_param"), &this->colmaker_param_); - } - void SaveConfig(Json* p_out) const override { - auto& out = *p_out; - out["train_param"] = ToJson(param_); - out["colmaker_train_param"] = ToJson(colmaker_param_); - } +bool IsFresh(bst_node_t nid) { + return nid >= 0; +} - char const* Name() const override { - return "grow_colmaker"; +// actual builder that runs the algorithm +class Builder { + public: + // constructor + explicit Builder(const TrainParam ¶m, + const ColMakerTrainParam &colmaker_train_param, + std::unique_ptr spliteval, + FeatureInteractionConstraintHost _interaction_constraints, + const std::vector &column_densities, common::Monitor* monitor) + : param_(param), colmaker_train_param_{colmaker_train_param}, + nthread_(omp_get_max_threads()), + spliteval_(std::move(spliteval)), interaction_constraints_{std::move( + _interaction_constraints)}, + column_densities_(column_densities), moniter_{monitor} { } + // update one tree, growing + virtual void Update(const std::vector& gpair, + DMatrix* p_fmat, + RegTree* p_tree) { + std::vector newnodes; + this->InitData(gpair, *p_fmat, *p_tree); + this->InitNewNode(qexpand_, gpair, *p_fmat, *p_tree); - void LazyGetColumnDensity(DMatrix *dmat) { - // Finds densities if we don't already have them - if (column_densities_.empty()) { - std::vector column_size(dmat->Info().num_col_); - for (const auto &batch : dmat->GetBatches()) { - for (auto i = 0u; i < batch.Size(); i++) { - column_size[i] += batch[i].size(); + for (int depth = 0; depth < param_.max_depth; ++depth) { + this->FindSplit(depth, qexpand_, gpair, p_fmat, p_tree); + this->ResetPosition(qexpand_, p_fmat, *p_tree); + this->UpdateQueueExpand(*p_tree, qexpand_, &newnodes); + this->InitNewNode(newnodes, gpair, *p_fmat, *p_tree); + for (auto nid : qexpand_) { + if ((*p_tree)[nid].IsLeaf()) { + continue; } + int cleft = (*p_tree)[nid].LeftChild(); + int cright = (*p_tree)[nid].RightChild(); + spliteval_->AddSplit(nid, + cleft, + cright, + snode_[nid].best.SplitIndex(), + snode_[cleft].weight, + snode_[cright].weight); + interaction_constraints_.Split(nid, snode_[nid].best.SplitIndex(), cleft, cright); } - column_densities_.resize(column_size.size()); - for 
(auto i = 0u; i < column_densities_.size(); i++) { - size_t nmiss = dmat->Info().num_row_ - column_size[i]; - column_densities_[i] = - 1.0f - (static_cast(nmiss)) / dmat->Info().num_row_; + qexpand_ = newnodes; + // if nothing left to be expand, break + if (qexpand_.size() == 0) { + break; } } - } - void Update(HostDeviceVector *gpair, - DMatrix* dmat, - const std::vector &trees) override { - if (rabit::IsDistributed()) { - LOG(FATAL) << "Updater `grow_colmaker` or `exact` tree method doesn't " - "support distributed training."; + // set all the rest expanding nodes to leaf + for (const int nid : qexpand_) { + // This unmarks the "fresh" node. + (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate); } - this->LazyGetColumnDensity(dmat); - // rescale learning rate according to size of trees - float lr = param_.learning_rate; - param_.learning_rate = lr / trees.size(); - interaction_constraints_.Configure(param_, dmat->Info().num_row_); - // build tree - for (auto tree : trees) { - Builder builder( - param_, - colmaker_param_, - std::unique_ptr(spliteval_->GetHostClone()), - interaction_constraints_, column_densities_); - builder.Update(gpair->ConstHostVector(), dmat, tree); + // remember auxiliary statistics in the tree node + for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) { + p_tree->Stat(nid).loss_chg = snode_[nid].best.loss_chg; + p_tree->Stat(nid).base_weight = snode_[nid].weight; + p_tree->Stat(nid).sum_hess = static_cast(snode_[nid].stats.sum_hess); } - param_.learning_rate = lr; } protected: - // training parameter - TrainParam param_; - ColMakerTrainParam colmaker_param_; - // SplitEvaluator that will be cloned for each Builder - std::unique_ptr spliteval_; - std::vector column_densities_; - - FeatureInteractionConstraintHost interaction_constraints_; - // data structure - /*! \brief per thread x per node entry to store tmp data */ - struct ThreadEntry { - /*! \brief statistics of data */ - GradStats stats; - /*! \brief last feature value scanned */ - bst_float last_fvalue { 0 }; - /*! \brief current best solution */ - SplitEntry best; - // constructor - ThreadEntry() = default; - }; - struct NodeEntry { - /*! \brief statics for node entry */ - GradStats stats; - /*! \brief loss of this node, without split */ - bst_float root_gain { 0.0f }; - /*! \brief weight calculated related to current data */ - bst_float weight { 0.0f }; - /*! 
\brief current best solution */ - SplitEntry best; - // constructor - NodeEntry() = default; - }; - // actual builder that runs the algorithm - class Builder { - public: - // constructor - explicit Builder(const TrainParam& param, - const ColMakerTrainParam& colmaker_train_param, - std::unique_ptr spliteval, - FeatureInteractionConstraintHost _interaction_constraints, - const std::vector &column_densities) - : param_(param), colmaker_train_param_{colmaker_train_param}, - nthread_(omp_get_max_threads()), - spliteval_(std::move(spliteval)), - interaction_constraints_{std::move(_interaction_constraints)}, - column_densities_(column_densities) {} - // update one tree, growing - virtual void Update(const std::vector& gpair, - DMatrix* p_fmat, - RegTree* p_tree) { - std::vector newnodes; - this->InitData(gpair, *p_fmat, *p_tree); - this->InitNewNode(qexpand_, gpair, *p_fmat, *p_tree); - for (int depth = 0; depth < param_.max_depth; ++depth) { - this->FindSplit(depth, qexpand_, gpair, p_fmat, p_tree); - this->ResetPosition(qexpand_, p_fmat, *p_tree); - this->UpdateQueueExpand(*p_tree, qexpand_, &newnodes); - this->InitNewNode(newnodes, gpair, *p_fmat, *p_tree); - for (auto nid : qexpand_) { - if ((*p_tree)[nid].IsLeaf()) { - continue; - } - int cleft = (*p_tree)[nid].LeftChild(); - int cright = (*p_tree)[nid].RightChild(); - spliteval_->AddSplit(nid, - cleft, - cright, - snode_[nid].best.SplitIndex(), - snode_[cleft].weight, - snode_[cright].weight); - interaction_constraints_.Split(nid, snode_[nid].best.SplitIndex(), cleft, cright); - } - qexpand_ = newnodes; - // if nothing left to be expand, break - if (qexpand_.size() == 0) break; - } - // set all the rest expanding nodes to leaf - for (const int nid : qexpand_) { - (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate); - } - // remember auxiliary statistics in the tree node - for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) { - p_tree->Stat(nid).loss_chg = snode_[nid].best.loss_chg; - p_tree->Stat(nid).base_weight = snode_[nid].weight; - p_tree->Stat(nid).sum_hess = static_cast(snode_[nid].stats.sum_hess); + // initialize temp data structure + inline void InitData(const std::vector& gpair, + const DMatrix& fmat, + const RegTree& tree) { + { + // setup position + position_.resize(gpair.size()); + CHECK_EQ(fmat.Info().num_row_, position_.size()); + std::fill(position_.begin(), position_.end(), 0); + // mark delete for the deleted datas + for (size_t ridx = 0; ridx < position_.size(); ++ridx) { + if (gpair[ridx].GetHess() < 0.0f) position_[ridx] = ~position_[ridx]; } - } - - protected: - // initialize temp data structure - inline void InitData(const std::vector& gpair, - const DMatrix& fmat, - const RegTree& tree) { - { - // setup position - position_.resize(gpair.size()); - CHECK_EQ(fmat.Info().num_row_, position_.size()); - std::fill(position_.begin(), position_.end(), 0); - // mark delete for the deleted datas - for (size_t ridx = 0; ridx < position_.size(); ++ridx) { - if (gpair[ridx].GetHess() < 0.0f) position_[ridx] = ~position_[ridx]; - } - // mark subsample - if (param_.subsample < 1.0f) { - CHECK_EQ(param_.sampling_method, TrainParam::kUniform) + // mark subsample + if (param_.subsample < 1.0f) { + CHECK_EQ(param_.sampling_method, TrainParam::kUniform) << "Only uniform sampling is supported, " << "gradient-based sampling is only support by GPU Hist."; - std::bernoulli_distribution coin_flip(param_.subsample); - auto& rnd = common::GlobalRandom(); - for (size_t ridx = 0; ridx < position_.size(); ++ridx) { - if 
(gpair[ridx].GetHess() < 0.0f) continue; - if (!coin_flip(rnd)) position_[ridx] = ~position_[ridx]; - } - } - } - { - column_sampler_.Init(fmat.Info().num_col_, param_.colsample_bynode, - param_.colsample_bylevel, param_.colsample_bytree); - } - { - // setup temp space for each thread - // reserve a small space - stemp_.clear(); - stemp_.resize(this->nthread_, std::vector()); - for (auto& i : stemp_) { - i.clear(); i.reserve(256); + std::bernoulli_distribution coin_flip(param_.subsample); + auto& rnd = common::GlobalRandom(); + for (size_t ridx = 0; ridx < position_.size(); ++ridx) { + if (gpair[ridx].GetHess() < 0.0f) continue; + if (!coin_flip(rnd)) position_[ridx] = ~position_[ridx]; } - snode_.reserve(256); - } - { - // expand query - qexpand_.reserve(256); qexpand_.clear(); - qexpand_.push_back(0); } } - /*! - * \brief initialize the base_weight, root_gain, - * and NodeEntry for all the new nodes in qexpand - */ - inline void InitNewNode(const std::vector& qexpand, - const std::vector& gpair, - const DMatrix& fmat, - const RegTree& tree) { - { - // setup statistics space for each tree node - for (auto& i : stemp_) { - i.resize(tree.param.num_nodes, ThreadEntry()); - } - snode_.resize(tree.param.num_nodes, NodeEntry()); - } - const MetaInfo& info = fmat.Info(); - // setup position - const auto ndata = static_cast(info.num_row_); - #pragma omp parallel for schedule(static) - for (bst_omp_uint ridx = 0; ridx < ndata; ++ridx) { - const int tid = omp_get_thread_num(); - if (position_[ridx] < 0) continue; - stemp_[tid][position_[ridx]].stats.Add(gpair[ridx]); + { + column_sampler_.Init(fmat.Info().num_col_, param_.colsample_bynode, + param_.colsample_bylevel, param_.colsample_bytree); + } + { + // setup temp space for each thread + // reserve a small space + stemp_.clear(); + stemp_.resize(this->nthread_, std::vector()); + for (auto& i : stemp_) { + i.clear(); i.reserve(256); } - // sum the per thread statistics together - for (int nid : qexpand) { - GradStats stats; - for (auto& s : stemp_) { - stats.Add(s[nid].stats); - } - // update node statistics - snode_[nid].stats = stats; + snode_.reserve(256); + } + { + // expand query + qexpand_.reserve(256); qexpand_.clear(); + qexpand_.push_back(0); + } + } + /*! + * \brief initialize the base_weight, root_gain, + * and NodeEntry for all the new nodes in qexpand + */ + inline void InitNewNode(const std::vector& qexpand, + const std::vector& gpair, + const DMatrix& fmat, + const RegTree& tree) { + { + // setup statistics space for each tree node + for (auto& i : stemp_) { + i.resize(tree.param.num_nodes, ThreadEntry()); } - // calculating the weights - for (int nid : qexpand) { - bst_uint parentid = tree[nid].Parent(); - snode_[nid].weight = static_cast( - spliteval_->ComputeWeight(parentid, snode_[nid].stats)); - snode_[nid].root_gain = static_cast( - spliteval_->ComputeScore(parentid, snode_[nid].stats, snode_[nid].weight)); + snode_.resize(tree.param.num_nodes, NodeEntry()); + } + const MetaInfo& info = fmat.Info(); + + // setup position + const auto ndata = static_cast(info.num_row_); +#pragma omp parallel for schedule(static) + for (bst_omp_uint ridx = 0; ridx < ndata; ++ridx) { + const int tid = omp_get_thread_num(); + if (position_[ridx] < 0) continue; + stemp_[tid][position_[ridx]].stats.Add(gpair[ridx]); + } + // sum the per thread statistics together + for (int nid : qexpand) { + GradStats stats; + for (auto& s : stemp_) { + stats.Add(s[nid].stats); } + // update node statistics + snode_[nid].stats = stats; } - /*! 
\brief update queue expand add in new leaves */ - inline void UpdateQueueExpand(const RegTree& tree, - const std::vector &qexpand, - std::vector* p_newnodes) { - p_newnodes->clear(); - for (int nid : qexpand) { - if (!tree[ nid ].IsLeaf()) { - p_newnodes->push_back(tree[nid].LeftChild()); - p_newnodes->push_back(tree[nid].RightChild()); - } + // calculating the weights + for (int nid : qexpand) { + bst_uint parentid = tree[nid].Parent(); + snode_[nid].weight = static_cast( + spliteval_->ComputeWeight(parentid, snode_[nid].stats)); + snode_[nid].root_gain = static_cast( + spliteval_->ComputeScore(parentid, snode_[nid].stats, snode_[nid].weight)); + } + } + /*! \brief update queue expand add in new leaves */ + inline void UpdateQueueExpand(const RegTree& tree, + const std::vector &qexpand, + std::vector* p_newnodes) { + p_newnodes->clear(); + for (int nid : qexpand) { + if (!tree[ nid ].IsLeaf()) { + p_newnodes->push_back(tree[nid].LeftChild()); + p_newnodes->push_back(tree[nid].RightChild()); } } + } - // update enumeration solution - inline void UpdateEnumeration(int nid, GradientPair gstats, - bst_float fvalue, int d_step, bst_uint fid, - GradStats &c, std::vector &temp) const { // NOLINT(*) + // same as EnumerateSplit, with cacheline prefetch optimization + void EnumerateSplit(const Entry *begin, const Entry *end, int d_step, + bst_uint fid, const std::vector &gpair, + std::vector &temp) const { // NOLINT(*) + CHECK(param_.cache_opt) << "Support for `cache_opt' is removed in 1.0.0"; + const std::vector &qexpand = qexpand_; + // clear all the temp statistics + for (auto nid : qexpand) { + temp[nid].stats = GradStats(); + } + // left statistics + GradStats c; + for (const Entry *it = begin; it != end; it += d_step) { + const bst_uint ridx = it->index; + const int nid = position_[ridx]; + if (!IsFresh(nid) || !interaction_constraints_.Query(nid, fid)) { + continue; + } + // start working + const bst_float fvalue = it->fvalue; // get the statistics of nid ThreadEntry &e = temp[nid]; // test if first hit, this is fine, because we set 0 during init if (e.stats.Empty()) { - e.stats.Add(gstats); + e.stats.Add(gpair[ridx]); e.last_fvalue = fvalue; } else { // try to find a split if (fvalue != e.last_fvalue && e.stats.sum_hess >= param_.min_child_weight) { c.SetSubstract(snode_[nid].stats, e.stats); + bst_float loss_chg { 0 }; if (c.sum_hess >= param_.min_child_weight) { - bst_float loss_chg {0}; if (d_step == -1) { loss_chg = static_cast( spliteval_->ComputeSplitScore(nid, fid, c, e.stats) - snode_[nid].root_gain); - bst_float proposed_split = (fvalue + e.last_fvalue) * 0.5f; - if ( proposed_split == fvalue ) { - e.best.Update(loss_chg, fid, e.last_fvalue, - d_step == -1, c, e.stats); - } else { - e.best.Update(loss_chg, fid, proposed_split, - d_step == -1, c, e.stats); - } + e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, + d_step == -1, c, e.stats); } else { loss_chg = static_cast( spliteval_->ComputeSplitScore(nid, fid, e.stats, c) - snode_[nid].root_gain); - bst_float proposed_split = (fvalue + e.last_fvalue) * 0.5f; - if ( proposed_split == fvalue ) { - e.best.Update(loss_chg, fid, e.last_fvalue, + e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, d_step == -1, e.stats, c); - } else { - e.best.Update(loss_chg, fid, proposed_split, - d_step == -1, e.stats, c); - } } } } // update the statistics - e.stats.Add(gstats); + e.stats.Add(gpair[ridx]); e.last_fvalue = fvalue; } } - // same as EnumerateSplit, with cacheline prefetch optimization - void EnumerateSplit(const 
Entry *begin, - const Entry *end, - int d_step, - bst_uint fid, - const std::vector &gpair, - std::vector &temp) const { // NOLINT(*) - CHECK(param_.cache_opt) << "Support for `cache_opt' is removed in 1.0.0"; - const std::vector &qexpand = qexpand_; - // clear all the temp statistics - for (auto nid : qexpand) { - temp[nid].stats = GradStats(); - } - // left statistics - GradStats c; - // local cache buffer for position and gradient pair - constexpr int kBuffer = 32; - int buf_position[kBuffer] = {}; - GradientPair buf_gpair[kBuffer] = {}; - // aligned ending position - const Entry *align_end; - if (d_step > 0) { - align_end = begin + (end - begin) / kBuffer * kBuffer; - } else { - align_end = begin - (begin - end) / kBuffer * kBuffer; - } - int i; - const Entry *it; - const int align_step = d_step * kBuffer; - // internal cached loop - for (it = begin; it != align_end; it += align_step) { - const Entry *p; - for (i = 0, p = it; i < kBuffer; ++i, p += d_step) { - buf_position[i] = position_[p->index]; - buf_gpair[i] = gpair[p->index]; - } - for (i = 0, p = it; i < kBuffer; ++i, p += d_step) { - const int nid = buf_position[i]; - if (nid < 0 || !interaction_constraints_.Query(nid, fid)) { continue; } - this->UpdateEnumeration(nid, buf_gpair[i], - p->fvalue, d_step, - fid, c, temp); - } - } - - // finish up the ending piece - for (it = align_end, i = 0; it != end; ++i, it += d_step) { - buf_position[i] = position_[it->index]; - buf_gpair[i] = gpair[it->index]; - } - for (it = align_end, i = 0; it != end; ++i, it += d_step) { - const int nid = buf_position[i]; - if (nid < 0 || !interaction_constraints_.Query(nid, fid)) { continue; } - this->UpdateEnumeration(nid, buf_gpair[i], - it->fvalue, d_step, - fid, c, temp); - } - // finish updating all statistics, check if it is possible to include all sum statistics - for (int nid : qexpand) { - ThreadEntry &e = temp[nid]; - c.SetSubstract(snode_[nid].stats, e.stats); - if (e.stats.sum_hess >= param_.min_child_weight && - c.sum_hess >= param_.min_child_weight) { - bst_float loss_chg; - const bst_float gap = std::abs(e.last_fvalue) + kRtEps; - const bst_float delta = d_step == +1 ? gap: -gap; - if (d_step == -1) { - loss_chg = static_cast( - spliteval_->ComputeSplitScore(nid, fid, c, e.stats) - - snode_[nid].root_gain); - e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, c, - e.stats); - } else { - loss_chg = static_cast( - spliteval_->ComputeSplitScore(nid, fid, e.stats, c) - - snode_[nid].root_gain); - e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, - e.stats, c); - } + // finish updating all statistics, check if it is possible to include all + // sum statistics + for (int nid : qexpand) { + ThreadEntry &e = temp[nid]; + c.SetSubstract(snode_[nid].stats, e.stats); + if (e.stats.sum_hess >= param_.min_child_weight && + c.sum_hess >= param_.min_child_weight) { + bst_float loss_chg; + GradStats left_sum; + GradStats right_sum; + if (d_step == -1) { + left_sum = c; + right_sum = e.stats; + } else { + left_sum = e.stats; + right_sum = c; } + loss_chg = static_cast( + spliteval_->ComputeSplitScore(nid, fid, left_sum, right_sum) - + snode_[nid].root_gain); + const bst_float gap = std::abs(e.last_fvalue) + kRtEps; + const bst_float delta = d_step == +1 ? 
gap : -gap; + e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, + left_sum, right_sum); } } + } - // update the solution candidate - virtual void UpdateSolution(const SparsePage &batch, - const std::vector &feat_set, - const std::vector &gpair, - DMatrix*p_fmat) { - // start enumeration - const auto num_features = static_cast(feat_set.size()); + // update the solution candidate + virtual void UpdateSolution(const SparsePage &batch, + const std::vector &feat_set, + const std::vector &gpair, + DMatrix* p_fmat) { + moniter_->Start(__func__); + // start enumeration + const auto num_features = static_cast(feat_set.size()); #if defined(_OPENMP) - const int batch_size = // NOLINT - std::max(static_cast(num_features / this->nthread_ / 32), 1); + const int batch_size = // NOLINT + std::max(static_cast(num_features / this->nthread_ / 32), 1); #endif // defined(_OPENMP) - { + { #pragma omp parallel for schedule(dynamic, batch_size) - for (bst_omp_uint i = 0; i < num_features; ++i) { - bst_feature_t const fid = feat_set[i]; - int32_t const tid = omp_get_thread_num(); - auto c = batch[fid]; - const bool ind = c.size() != 0 && c[0].fvalue == c[c.size() - 1].fvalue; - if (colmaker_train_param_.NeedForwardSearch( - param_.default_direction, column_densities_[fid], ind)) { - this->EnumerateSplit(c.data(), c.data() + c.size(), +1, - fid, gpair, stemp_[tid]); - } - if (colmaker_train_param_.NeedBackwardSearch(param_.default_direction)) { - this->EnumerateSplit(c.data() + c.size() - 1, c.data() - 1, -1, - fid, gpair, stemp_[tid]); - } + for (bst_omp_uint i = 0; i < num_features; ++i) { + bst_feature_t const fid = feat_set[i]; + int32_t const tid = omp_get_thread_num(); + auto c = batch[fid]; + const bool ind = c.size() != 0 && c[0].fvalue == c[c.size() - 1].fvalue; + if (colmaker_train_param_.NeedForwardSearch( + param_.default_direction, column_densities_[fid], ind)) { + this->EnumerateSplit(c.data(), c.data() + c.size(), +1, + fid, gpair, stemp_[tid]); + } + if (colmaker_train_param_.NeedBackwardSearch(param_.default_direction)) { + this->EnumerateSplit(c.data() + c.size() - 1, c.data() - 1, -1, + fid, gpair, stemp_[tid]); } } } - // find splits at current level, do split per level - inline void FindSplit(int depth, - const std::vector &qexpand, - const std::vector &gpair, - DMatrix *p_fmat, - RegTree *p_tree) { - auto feat_set = column_sampler_.GetFeatureSet(depth); - for (const auto &batch : p_fmat->GetBatches()) { - this->UpdateSolution(batch, feat_set->HostVector(), gpair, p_fmat); - } - // after this each thread's stemp will get the best candidates, aggregate results - this->SyncBestSolution(qexpand); - // get the best result, we can synchronize the solution - for (int nid : qexpand) { - NodeEntry const &e = snode_[nid]; - // now we know the solution in snode[nid], set split - if (e.best.loss_chg > kRtEps) { - bst_float left_leaf_weight = - spliteval_->ComputeWeight(nid, e.best.left_sum) * - param_.learning_rate; - bst_float right_leaf_weight = - spliteval_->ComputeWeight(nid, e.best.right_sum) * - param_.learning_rate; - p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, - e.best.DefaultLeft(), e.weight, left_leaf_weight, - right_leaf_weight, e.best.loss_chg, - e.stats.sum_hess, - e.best.left_sum.GetHess(), e.best.right_sum.GetHess(), - 0); - } else { - (*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate); - } + moniter_->Stop(__func__); + } + // find splits at current level, do split per level + inline void FindSplit(int depth, + const std::vector &qexpand, + const 
std::vector &gpair, + DMatrix *p_fmat, + RegTree *p_tree) { + auto feat_set = column_sampler_.GetFeatureSet(depth); + for (const auto &batch : p_fmat->GetBatches()) { + this->UpdateSolution(batch, feat_set->HostVector(), gpair, p_fmat); + } + // after this each thread's stemp will get the best candidates, aggregate results + this->SyncBestSolution(qexpand); + // get the best result, we can synchronize the solution + for (int nid : qexpand) { + NodeEntry const &e = snode_[nid]; + // now we know the solution in snode[nid], set split + if (e.best.loss_chg > kRtEps) { + bst_float left_leaf_weight = + spliteval_->ComputeWeight(nid, e.best.left_sum) * + param_.learning_rate; + bst_float right_leaf_weight = + spliteval_->ComputeWeight(nid, e.best.right_sum) * + param_.learning_rate; + // This marks the new nodes to be "fresh" + p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, + e.best.DefaultLeft(), e.weight, left_leaf_weight, + right_leaf_weight, e.best.loss_chg, + e.stats.sum_hess, + e.best.left_sum.GetHess(), e.best.right_sum.GetHess(), + 0); + } else { + // Unmarks "fresh" node. + (*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate); } } - // reset position of each data points after split is created in the tree - inline void ResetPosition(const std::vector &qexpand, - DMatrix* p_fmat, - const RegTree& tree) { - // set the positions in the nondefault - this->SetNonDefaultPosition(qexpand, p_fmat, tree); - // set rest of instances to default position - // set default direct nodes to default - // for leaf nodes that are not fresh, mark then to ~nid, - // so that they are ignored in future statistics collection - const auto ndata = static_cast(p_fmat->Info().num_row_); + } + // reset position of each data points after split is created in the tree + inline void ResetPosition(const std::vector &qexpand, + DMatrix* p_fmat, + const RegTree& tree) { + // set the positions in the nondefault + this->SetNonDefaultPosition(qexpand, p_fmat, tree); + // set rest of instances to default position + // set default direct nodes to default + // for leaf nodes that are not fresh, mark then to ~nid, + // so that they are ignored in future statistics collection + const auto ndata = static_cast(p_fmat->Info().num_row_); #pragma omp parallel for schedule(static) - for (bst_omp_uint ridx = 0; ridx < ndata; ++ridx) { - CHECK_LT(ridx, position_.size()) - << "ridx exceed bound " << "ridx="<< ridx << " pos=" << position_.size(); - const int nid = this->DecodePosition(ridx); - if (tree[nid].IsLeaf()) { - // mark finish when it is not a fresh leaf - if (tree[nid].RightChild() == -1) { - position_[ridx] = ~nid; - } + for (bst_omp_uint ridx = 0; ridx < ndata; ++ridx) { + CHECK_LT(ridx, position_.size()) + << "ridx exceed bound " << "ridx="<< ridx << " pos=" << position_.size(); + const int nid = this->DecodePosition(ridx); + if (tree[nid].IsLeaf()) { + // mark finish when it is not a fresh leaf + if (tree[nid].RightChild() == -1) { + position_[ridx] = ~nid; + } + } else { + // push to default branch + if (tree[nid].DefaultLeft()) { + this->SetEncodePosition(ridx, tree[nid].LeftChild()); } else { - // push to default branch - if (tree[nid].DefaultLeft()) { - this->SetEncodePosition(ridx, tree[nid].LeftChild()); - } else { - this->SetEncodePosition(ridx, tree[nid].RightChild()); - } + this->SetEncodePosition(ridx, tree[nid].RightChild()); } } } - // customization part - // synchronize the best solution of each node - virtual void SyncBestSolution(const std::vector &qexpand) { - for (int nid : qexpand) { - 
NodeEntry &e = snode_[nid]; - for (int tid = 0; tid < this->nthread_; ++tid) { - e.best.Update(stemp_[tid][nid].best); - } + } + // customization part + // synchronize the best solution of each node + virtual void SyncBestSolution(const std::vector &qexpand) { + for (int nid : qexpand) { + NodeEntry &e = snode_[nid]; + for (int tid = 0; tid < this->nthread_; ++tid) { + e.best.Update(stemp_[tid][nid].best); } } - virtual void SetNonDefaultPosition(const std::vector &qexpand, - DMatrix *p_fmat, - const RegTree &tree) { - // step 1, classify the non-default data into right places - std::vector fsplits; - for (int nid : qexpand) { - if (!tree[nid].IsLeaf()) { - fsplits.push_back(tree[nid].SplitIndex()); - } + } + virtual void SetNonDefaultPosition(const std::vector &qexpand, + DMatrix *p_fmat, + const RegTree &tree) { + // step 1, classify the non-default data into right places + std::vector fsplits; + for (int nid : qexpand) { + if (!tree[nid].IsLeaf()) { + fsplits.push_back(tree[nid].SplitIndex()); } - std::sort(fsplits.begin(), fsplits.end()); - fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin()); - for (const auto &batch : p_fmat->GetBatches()) { - for (auto fid : fsplits) { - auto col = batch[fid]; - const auto ndata = static_cast(col.size()); + } + std::sort(fsplits.begin(), fsplits.end()); + fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin()); + for (const auto &batch : p_fmat->GetBatches()) { + for (auto fid : fsplits) { + auto col = batch[fid]; + const auto ndata = static_cast(col.size()); #pragma omp parallel for schedule(static) - for (bst_omp_uint j = 0; j < ndata; ++j) { - const bst_uint ridx = col[j].index; - const int nid = this->DecodePosition(ridx); - const bst_float fvalue = col[j].fvalue; - // go back to parent, correct those who are not default - if (!tree[nid].IsLeaf() && tree[nid].SplitIndex() == fid) { - if (fvalue < tree[nid].SplitCond()) { - this->SetEncodePosition(ridx, tree[nid].LeftChild()); - } else { - this->SetEncodePosition(ridx, tree[nid].RightChild()); - } + for (bst_omp_uint j = 0; j < ndata; ++j) { + const bst_uint ridx = col[j].index; + const int nid = this->DecodePosition(ridx); + const bst_float fvalue = col[j].fvalue; + // go back to parent, correct those who are not default + if (!tree[nid].IsLeaf() && tree[nid].SplitIndex() == fid) { + if (fvalue < tree[nid].SplitCond()) { + this->SetEncodePosition(ridx, tree[nid].LeftChild()); + } else { + this->SetEncodePosition(ridx, tree[nid].RightChild()); } } } } } - // utils to get/set position, with encoded format - // return decoded position - inline int DecodePosition(bst_uint ridx) const { - const int pid = position_[ridx]; - return pid < 0 ? ~pid : pid; + } + // utils to get/set position, with encoded format + // return decoded position + inline int DecodePosition(bst_uint ridx) const { + const int pid = position_[ridx]; + return pid < 0 ? 
~pid : pid; + } + // encode the encoded position value for ridx + inline void SetEncodePosition(bst_uint ridx, int nid) { + if (position_[ridx] < 0) { + position_[ridx] = ~nid; + } else { + position_[ridx] = nid; } - // encode the encoded position value for ridx - inline void SetEncodePosition(bst_uint ridx, int nid) { - if (position_[ridx] < 0) { - position_[ridx] = ~nid; - } else { - position_[ridx] = nid; + } + // --data fields-- + const TrainParam& param_; + const ColMakerTrainParam& colmaker_train_param_; + // number of omp thread used during training + const int nthread_; + common::ColumnSampler column_sampler_; + // Instance Data: current node position in the tree of each instance + std::vector position_; + // PerThread x PerTreeNode: statistics for per thread construction + std::vector< std::vector > stemp_; + /*! \brief TreeNode Data: statistics for each constructed node */ + std::vector snode_; + /*! \brief queue of nodes to be expanded */ + std::vector qexpand_; + // Evaluates splits and computes optimal weights for a given split + std::unique_ptr spliteval_; + + FeatureInteractionConstraintHost interaction_constraints_; + const std::vector &column_densities_; + common::Monitor* moniter_; +}; + +} // anonymous namespace + +/*! \brief column-wise update to construct a tree */ +class ColMaker: public TreeUpdater { + public: + ColMaker() { + moniter_.Init(__func__); + } + void Configure(const Args& args) override { + param_.UpdateAllowUnknown(args); + colmaker_param_.UpdateAllowUnknown(args); + if (!spliteval_) { + spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator)); + } + spliteval_->Init(¶m_); + } + + void LoadConfig(Json const& in) override { + auto const& config = get(in); + FromJson(config.at("train_param"), &this->param_); + FromJson(config.at("colmaker_train_param"), &this->colmaker_param_); + } + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["train_param"] = ToJson(param_); + out["colmaker_train_param"] = ToJson(colmaker_param_); + } + + char const* Name() const override { + return "grow_colmaker_deprecated"; + } + + void LazyGetColumnDensity(DMatrix *dmat) { + // Finds densities if we don't already have them + if (column_densities_.empty()) { + std::vector column_size(dmat->Info().num_col_); + for (const auto &batch : dmat->GetBatches()) { + for (auto i = 0u; i < batch.Size(); i++) { + column_size[i] += batch[i].size(); + } + } + column_densities_.resize(column_size.size()); + for (auto i = 0u; i < column_densities_.size(); i++) { + size_t nmiss = dmat->Info().num_row_ - column_size[i]; + column_densities_[i] = + 1.0f - (static_cast(nmiss)) / dmat->Info().num_row_; } } - // --data fields-- - const TrainParam& param_; - const ColMakerTrainParam& colmaker_train_param_; - // number of omp thread used during training - const int nthread_; - common::ColumnSampler column_sampler_; - // Instance Data: current node position in the tree of each instance - std::vector position_; - // PerThread x PerTreeNode: statistics for per thread construction - std::vector< std::vector > stemp_; - /*! \brief TreeNode Data: statistics for each constructed node */ - std::vector snode_; - /*! 
\brief queue of nodes to be expanded */ - std::vector qexpand_; - // Evaluates splits and computes optimal weights for a given split - std::unique_ptr spliteval_; + } - FeatureInteractionConstraintHost interaction_constraints_; - const std::vector &column_densities_; - }; + void Update(HostDeviceVector *gpair, + DMatrix* dmat, + const std::vector &trees) override { + if (rabit::IsDistributed()) { + LOG(FATAL) << "Updater `grow_colmaker` or `exact` tree method doesn't " + "support distributed training."; + } + this->LazyGetColumnDensity(dmat); + // rescale learning rate according to size of trees + float lr = param_.learning_rate; + param_.learning_rate = lr / trees.size(); + interaction_constraints_.Configure(param_, dmat->Info().num_row_); + // build tree + for (auto tree : trees) { + Builder builder( + param_, + colmaker_param_, + std::unique_ptr(spliteval_->GetHostClone()), + interaction_constraints_, column_densities_, &moniter_); + builder.Update(gpair->ConstHostVector(), dmat, tree); + } + param_.learning_rate = lr; + } + + protected: + // training parameter + TrainParam param_; + ColMakerTrainParam colmaker_param_; + // SplitEvaluator that will be cloned for each Builder + std::unique_ptr spliteval_; + std::vector column_densities_; + + FeatureInteractionConstraintHost interaction_constraints_; + common::Monitor moniter_; + // data structure + /*! \brief per thread x per node entry to store tmp data */ }; -XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker") +XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker_deprecated") .describe("Grow tree with parallelization over columns.") -.set_body([]() { +.set_body([](GenericParameter const* tparam, LearnerModelParam const* mparam) { return new ColMaker(); }); } // namespace tree diff --git a/src/tree/updater_exact.cc b/src/tree/updater_exact.cc new file mode 100644 index 000000000000..2a0dd8c21dd9 --- /dev/null +++ b/src/tree/updater_exact.cc @@ -0,0 +1,448 @@ +/*! 
+ * Copyright 2020 by XGBoost Contributors + * \file updater_exact.cc + */ +#include +#include +#include +#include +#include + +#include "xgboost/data.h" +#include "xgboost/host_device_vector.h" +#include "xgboost/tree_model.h" +#include "xgboost/tree_updater.h" +#include "xgboost/span.h" +#include "xgboost/base.h" +#include "xgboost/json.h" + +#include "param.h" +#include "updater_exact.h" + +namespace xgboost { + +namespace { +template +void SetSubstract(GradientT const &lhs, GradientT const &rhs, GradientT *out) { + auto* out_gvec = &out->GetGrad()[0]; + auto* out_hvec = &out->GetHess()[0]; + auto const* l_gvec = &lhs.GetGrad()[0]; + auto const* l_hvec = &lhs.GetHess()[0]; + auto const* r_gvec = &rhs.GetGrad()[0]; + auto const* r_hvec = &rhs.GetHess()[0]; + size_t const size = lhs.GetGrad().Size(); + for (size_t i = 0; i < size; i++) { + out_gvec[i] = l_gvec[i] - r_gvec[i]; + out_hvec[i] = l_hvec[i] - r_hvec[i]; + } +} + +template <> +void SetSubstract(SingleGradientPair const &lhs, + SingleGradientPair const &rhs, + SingleGradientPair *out) { + out->GetGrad().vec = lhs.GetGrad().vec - rhs.GetGrad().vec; + out->GetHess().vec = lhs.GetHess().vec - rhs.GetHess().vec; +} +} // anonymous namespace + +namespace tree { + +template +void MultiExact::InitData(DMatrix *data, + common::Span gpairs) { + monitor_.Start(__func__); + this->positions_.clear(); + this->is_splitable_.clear(); + this->nodes_split_.clear(); + this->node_shift_ = 0; + + CHECK_EQ(gpairs.size(), data->Info().num_row_ * mparam_->num_targets); + gpairs_ = std::vector(gpairs.size() / mparam_->num_targets, + MakeGradientPair(mparam_->num_targets)); + CHECK_EQ(gpairs_.size(), data->Info().num_row_); + is_splitable_.resize(param_.MaxNodes(), 1); + + tloc_scans_.resize(omp_get_max_threads()); + for (auto& scan : tloc_scans_) { + scan.resize(param_.MaxNodes()); + } + + auto subsample = param_.subsample; + + // Get a vectorized veiw of gradients. 
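  // A minimal, standalone sketch of the reshaping done by the loop below,
  // assuming plain std::vector in place of the GradientPair/MultiGradientPair
  // types (names and types here are illustrative only, not part of the patch):
  //
  //   #include <cstddef>
  //   #include <utility>
  //   #include <vector>
  //
  //   struct RowGrad { std::vector<float> grad, hess; };
  //
  //   // flat is row-major: entry i * n_targets + j holds the (grad, hess) of
  //   // target j for row i; the result has one vector-valued pair per row.
  //   std::vector<RowGrad> Vectorize(std::vector<std::pair<float, float>> const& flat,
  //                                  std::size_t n_rows, std::size_t n_targets) {
  //     std::vector<RowGrad> out(n_rows, RowGrad{std::vector<float>(n_targets),
  //                                              std::vector<float>(n_targets)});
  //     for (std::size_t i = 0; i < n_rows; ++i) {
  //       for (std::size_t j = 0; j < n_targets; ++j) {
  //         out[i].grad[j] = flat[i * n_targets + j].first;
  //         out[i].hess[j] = flat[i * n_targets + j].second;
  //       }
  //     }
  //     return out;
  //   }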
+ for (size_t i = 0; i < data->Info().num_row_; ++i) { + size_t beg = i * mparam_->num_targets; + size_t end = beg + mparam_->num_targets; + auto &vec = gpairs_[i]; + for (size_t j = beg; j < end; ++j) { + vec.GetGrad()[j - beg] = gpairs[j].GetGrad(); + vec.GetHess()[j - beg] = gpairs[j].GetHess(); + } + } + + if (subsample != 1.0) { + size_t targets = mparam_->num_targets; + std::bernoulli_distribution flip(subsample); + auto &rnd = common::GlobalRandom(); + std::transform(gpairs_.begin(), gpairs_.end(), gpairs_.begin(), + [&flip, &rnd, targets](GradientT &g) { + if (!flip(rnd)) { + return MakeGradientPair(targets); + } + return g; + }); + } + + sampler_.Init(data->Info().num_col_, param_.colsample_bynode, + param_.colsample_bylevel, param_.colsample_bytree); + + value_constraints_.Init(param_, &monotone_constriants_); + monitor_.Stop(__func__); +} + +template +void MultiExact::InitRoot(DMatrix *data, RegTree *tree) { + monitor_.Start(__func__); + GradientT root_sum {MakeGradientPair(mparam_->num_targets)}; + root_sum = + XGBOOST_PARALLEL_ACCUMULATE(gpairs_.cbegin(), gpairs_.cend(), root_sum, + std::plus{}); + + auto weight = value_constraints_.CalcWeight(root_sum, RegTree::kRoot, param_); + tree->SetLeaf((weight * param_.learning_rate).vec, RegTree::kRoot, + root_sum.GetHess().vec); + + positions_.resize(data->Info().num_row_); + std::fill(positions_.begin(), positions_.end(), RegTree::kRoot); + auto gain = MultiCalcGainGivenWeight(root_sum.GetGrad(), + root_sum.GetHess(), + weight, param_); + SplitEntry root{RegTree::kRoot, root_sum, gain, param_}; + nodes_split_.push_back(root); + + auto p_feature_set = sampler_.GetFeatureSet(0); + this->EvaluateSplit(data, p_feature_set->HostSpan()); + monitor_.Stop(__func__); +} + +template +void MultiExact::EvaluateFeature(bst_feature_t fid, + SparsePage::Inst const &column, + std::vector *p_scans, + std::vector *p_nodes) const { + auto update_node = [fid, this](bool forward, SplitEntry const &scan, + float fcond, float bcond, SplitEntry *node) { + if (forward) { + float loss_chg = value_constraints_.CalcSplitGain( + scan.candidate.left_sum, scan.candidate.right_sum, + node->nidx, fid, param_) - + node->root_gain; + node->candidate.Update(loss_chg, fid, fcond, !forward, + scan.candidate.left_sum, scan.candidate.right_sum); + } else { + float loss_chg = value_constraints_.CalcSplitGain( + scan.candidate.right_sum, scan.candidate.left_sum, + node->nidx, fid, param_) - + node->root_gain; + node->candidate.Update(loss_chg, fid, bcond, !forward, + scan.candidate.right_sum, scan.candidate.left_sum); + } + }; + + auto search_kernel = [fid, this, p_scans, p_nodes, update_node]( + Entry const *const beg, Entry const *const end, + bool const forward) { + auto const inc = forward ? 
1 : -1; + auto& node_scans = *p_scans; + size_t targets { this->mparam_->num_targets }; + for (size_t i = node_shift_; i < nodes_split_.size(); ++i) { + auto& scan = node_scans[i]; + scan.candidate.left_sum = MakeGradientPair(targets); + scan.candidate.right_sum = MakeGradientPair(targets); + scan.last_value = std::numeric_limits::quiet_NaN(); + } + + auto &node_splits = *p_nodes; + auto const* p_gpairs = gpairs_.data(); + auto const* p_positions = positions_.data(); + + for (auto it = beg; it != end; it += inc) { + bst_node_t const row_nidx = p_positions[it->index]; + if (is_splitable_[row_nidx] == 0 || + !interaction_constraints_.Query(row_nidx, fid)) { + continue; + } + SplitEntry &scan = node_scans[row_nidx]; + if (AnyLT(scan.candidate.left_sum.GetHess(), param_.min_child_weight) || + it->fvalue == scan.last_value || std::isnan(scan.last_value)) { + scan.Accumulate(p_gpairs[it->index], it->fvalue); + continue; + } + + SplitEntry &node = node_splits[row_nidx]; + SetSubstract(node.parent_sum, scan.candidate.left_sum, + &scan.candidate.right_sum); + + if (AnyLT(scan.candidate.right_sum.GetHess(), param_.min_child_weight)) { + scan.Accumulate(p_gpairs[it->index], it->fvalue); + continue; + } + + float const cond = (it->fvalue + scan.last_value) * 0.5f; + update_node(forward, scan, cond, cond, &node); + scan.Accumulate(p_gpairs[it->index], it->fvalue); + } + + // Try to use all statistic from current column. + size_t const n_nodes = node_splits.size(); + for (size_t n = node_shift_; n < n_nodes; ++n) { + auto &node = node_splits[n]; + auto &scan = node_scans[n]; + SetSubstract(node.parent_sum, scan.candidate.left_sum, + &scan.candidate.right_sum); + if (AnyLT(scan.candidate.left_sum.GetHess(), param_.min_child_weight) || + AnyLT(scan.candidate.right_sum.GetHess(), param_.min_child_weight)) { + continue; + } + bst_float const gap = std::abs(scan.last_value) + kRtEps; + update_node(forward, scan, gap, -gap, &node); + } + }; + + CHECK_LE(p_nodes->size(), param_.MaxNodes()); + if (NeedForward(column, param_)) { + search_kernel(column.data(), column.data() + column.size(), true); + } + + if (NeedBackward(column, param_)) { + search_kernel(column.data() + column.size() - 1, column.data() - 1, false); + } +} + +template +void MultiExact::EvaluateSplit( + DMatrix *data, common::Span features) { + monitor_.Start(__func__); + for (auto const &batch : data->GetBatches()) { + CHECK_EQ(batch.Size(), data->Info().num_col_); + std::vector> tloc_splits(omp_get_max_threads()); + for (auto& s : tloc_splits) { + s = nodes_split_; + } + + dmlc::OMPException omp_handler; +#pragma omp parallel for schedule(dynamic) + for (omp_ulong f = 0; f < features.size(); ++f) { // NOLINT + omp_handler.Run([&]() { + auto fid = features[f]; + auto const &column = batch[fid]; + auto& splits = tloc_splits.at(omp_get_thread_num()); + auto& node_scans = tloc_scans_.at(omp_get_thread_num()); + this->EvaluateFeature(fid, column, &node_scans, &splits); + }); + } + omp_handler.Rethrow(); + + for (auto const& splits : tloc_splits) { + for (size_t i = node_shift_; i < splits.size(); ++i) { + nodes_split_.at(splits.at(i).nidx).candidate.Update(splits.at(i).candidate); + } + } + } + monitor_.Stop(__func__); +} + +template +size_t MultiExact::ExpandTree(RegTree *p_tree, + std::vector *next) { + auto& pending = *next; + auto &tree = *p_tree; + size_t max_node { 0 }; + auto const leaves = tree.GetNumLeaves(); + + for (size_t n = node_shift_; n < nodes_split_.size(); ++n) { + SplitEntry& split = nodes_split_.at(n); + auto weight = + 
value_constraints_.CalcWeight(split.parent_sum, split.nidx, param_); + if (!split.IsValid(tree.GetDepth(split.nidx), leaves, param_)) { + tree.SetLeaf((weight * param_.learning_rate).vec, split.nidx, + split.parent_sum.GetHess().vec); + CHECK_EQ(is_splitable_[split.nidx], 1); + is_splitable_[split.nidx] = 0; + continue; + } + + CHECK_NE(split.candidate.left_sum.GetGrad().Size(), 0); + auto left_weight = value_constraints_.CalcWeight(split.candidate.left_sum, + split.nidx, param_); + CHECK_NE(split.candidate.right_sum.GetGrad().Size(), 0); + auto right_weight = value_constraints_.CalcWeight(split.candidate.right_sum, + split.nidx, param_); + tree.ExpandNode(split.nidx, + split.candidate.SplitIndex(), + split.candidate.split_value, + split.candidate.DefaultLeft(), + weight.vec, + (left_weight * param_.learning_rate).vec, + (right_weight * param_.learning_rate).vec, + split.candidate.loss_chg, + split.parent_sum.GetHess().vec, + split.candidate.left_sum.GetHess().vec, + split.candidate.right_sum.GetHess().vec); + auto left = tree[split.nidx].LeftChild(); + auto right = tree[split.nidx].RightChild(); + interaction_constraints_.Split(split.nidx, split.candidate.SplitIndex(), left, right); + value_constraints_.Split(split.nidx, left, left_weight, right, right_weight, + split.candidate.SplitIndex()); + + if (SplitEntry::ChildIsValid(tree.GetDepth(left), leaves, param_)) { + auto gain = MultiCalcGainGivenWeight(split.candidate.left_sum.GetGrad(), + split.candidate.left_sum.GetHess(), + left_weight, param_); + SplitEntry s { left, split.candidate.left_sum, gain, param_ }; + pending.push_back(s); + CHECK_EQ(is_splitable_[left], 1); + max_node = std::max(max_node, static_cast(left)); + } else { + is_splitable_[left] = 0; + } + if (SplitEntry::ChildIsValid(tree.GetDepth(right), leaves, param_)) { + auto gain = MultiCalcGainGivenWeight(split.candidate.right_sum.GetGrad(), + split.candidate.right_sum.GetHess(), + right_weight, param_); + SplitEntry s { right, split.candidate.right_sum, gain, param_ }; + pending.push_back(s); + CHECK_EQ(is_splitable_[right], 1); + max_node = std::max(max_node, static_cast(right)); + } else { + is_splitable_[right] = 0; + } + } + return max_node; +} + +template +void MultiExact::ApplySplit(DMatrix *m, RegTree *p_tree) { + monitor_.Start(__func__); + auto &tree = *p_tree; + decltype(nodes_split_) pending; + auto max_node = this->ExpandTree(p_tree, &pending); + + // Fill in non-missing values. + std::vector fsplits; + for (size_t i = node_shift_; i < nodes_split_.size(); ++i) { + auto const& split = nodes_split_.at(i); + if (!tree[split.nidx].IsLeaf()) { + fsplits.push_back(tree[split.nidx].SplitIndex()); + } + } + + node_shift_ = nodes_split_.size(); + std::sort(fsplits.begin(), fsplits.end()); + fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin()); + for (const auto &batch : m->GetBatches()) { + for (auto fid : fsplits) { + auto col = batch[fid]; + const auto ndata = static_cast(col.size()); +#pragma omp parallel for schedule(static) + for (omp_ulong j = 0; j < ndata; ++j) { + const bst_uint ridx = col[j].index; + bst_node_t nidx = positions_[ridx]; + const bst_float fvalue = col[j].fvalue; + if (!tree[nidx].IsLeaf() && tree[nidx].SplitIndex() == fid) { + if (fvalue < tree[nidx].SplitCond()) { + positions_[ridx] = tree[nidx].LeftChild(); + } else { + positions_[ridx] = tree[nidx].RightChild(); + } + } + } + } + } + + // Fill in the missing values. 
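  // Rows whose split feature value was missing are untouched by the column
  // pass above and still point at the node that was just expanded; the loop
  // below pushes them down that node's default branch. A hedged sketch of the
  // rule with a hypothetical minimal node type (not the RegTree API):
  //
  //   struct MiniNode { bool is_leaf; bool default_left; int left, right; };
  //
  //   int RouteMissing(MiniNode const& n, int nid) {
  //     if (n.is_leaf) return nid;                  // leaves keep their rows
  //     return n.default_left ? n.left : n.right;   // otherwise default child
  //   }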
+#pragma omp parallel for schedule(static) + for (omp_ulong r = 0; r < m->Info().num_row_; ++r) { + auto nid = positions_[r]; + if (!tree[nid].IsLeaf()) { + if (tree[nid].DefaultLeft()) { + positions_[r] = tree[nid].LeftChild(); + } else { + positions_[r] = tree[nid].RightChild(); + } + } + } + + if (nodes_split_.size() < max_node + 1) { + nodes_split_.resize(max_node + 1); + CHECK_LE(nodes_split_.size(), param_.MaxNodes()); + } + for (auto split : pending) { + nodes_split_.at(split.nidx) = split; + } + monitor_.Stop(__func__); +} + +template +void MultiExact::UpdateTree(HostDeviceVector *gpair, + DMatrix *data, RegTree *tree) { + this->InitData(data, gpair->ConstHostSpan()); + this->InitRoot(data, tree); + this->ApplySplit(data, tree); + + size_t depth { 1 }; + while (nodes_split_.size() - node_shift_ != 0) { + auto p_feature_set = sampler_.GetFeatureSet(depth); + this->EvaluateSplit(data, p_feature_set->HostSpan()); + this->ApplySplit(data, tree); + depth++; + } +} + +template class MultiExact; +template class MultiExact; + +class MultiExactUpdater : public TreeUpdater { + using SingleTargetExact = MultiExact; + using MultiTargetExact = MultiExact; + + SingleTargetExact single_; + MultiTargetExact multi_; + + public: + explicit MultiExactUpdater(GenericParameter const *tparam, LearnerModelParam const* mparam) + : single_{tparam, mparam}, multi_{tparam, mparam} {} + char const *Name() const override { return single_.Name(); }; + void Configure(const Args &args) override { + single_.Configure(args); + multi_.Configure(args); + } + void LoadConfig(Json const& in) override { + single_.LoadConfig(in); + multi_.LoadConfig(in); + } + void SaveConfig(Json* p_out) const override { + single_.SaveConfig(p_out); + multi_.SaveConfig(p_out); + } + void Update(HostDeviceVector* gpair, + DMatrix* data, + const std::vector& trees) override { + CHECK_NE(trees.size(), 0); + if (trees.front()->Kind() == RegTree::kSingle) { + single_.Update(gpair, data, trees); + } else { + multi_.Update(gpair, data, trees); + } + } +}; + +XGBOOST_REGISTER_TREE_UPDATER(MultiExact, "grow_colmaker") + .describe("Grow tree with parallelization over columns.") + .set_body([](GenericParameter const *tparam, LearnerModelParam const* mparam) { + return new MultiExactUpdater(tparam, mparam); + }); + +} // namespace tree +} // namespace xgboost diff --git a/src/tree/updater_exact.h b/src/tree/updater_exact.h new file mode 100644 index 000000000000..f90e94546590 --- /dev/null +++ b/src/tree/updater_exact.h @@ -0,0 +1,516 @@ +/*! + * Copyright 2020 by XGBoost Contributors + * \file updater_multi_exact.h + * \brief Implementation of exact tree method for training multi-target trees. + */ +#ifndef XGBOOST_TREE_UPDATER_EXACT_H_ +#define XGBOOST_TREE_UPDATER_EXACT_H_ + +#include +#include +#include + +#include "xgboost/tree_updater.h" +#include "xgboost/tree_model.h" +#include "xgboost/json.h" +#include "param.h" +#include "constraints.h" +#include "../common/random.h" +#include "../common/timer.h" + +namespace xgboost { +/* \brief A simple wrapper around `std::vector`. Not much numeric computation is + * needed for XGBoost so we just hand fusing all vector operations without using + * expression template. 
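 *
 * Rough usage, assuming two targets (illustrative values only):
 *
 *   Vector<double> a(2, 1.0), b(2, 2.0);
 *   Vector<double> sum = a + b;       // element-wise:     {3.0, 3.0}
 *   Vector<double> half = sum * 0.5;  // scalar broadcast: {1.5, 1.5}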
+ */ +template +struct Vector { + using ValueT = Type; + + std::vector vec; + + Vector() = default; + explicit Vector(size_t n, Type v = Type{}) : vec(n, v) {} // NOLINT + + void Resize(size_t s) { return vec.resize(s); } + + size_t Size() const { return vec.size(); } + + template + Vector BinaryOp(Vector const& that, Op op) const { + CHECK_EQ(vec.size(), that.vec.size()); + Vector ret(that.vec.size()); + for (size_t i = 0; i < vec.size(); ++i) { + ret.vec[i] = op(vec[i], that.vec[i]); + } + return ret; + } + template + Vector BinaryScalar(Type const& that, Op op) const { + Vector ret(Size()); + for (size_t i = 0; i < vec.size(); ++i) { + ret[i] = op(vec[i], that); + } + return ret; + } + + Vector operator+(Vector const& that) const { + return BinaryOp(that, std::plus()); + } + Vector operator+(Type const& that) const { + return BinaryScalar(that, std::plus()); + } + Vector& operator+=(Vector const& that) { + size_t size = vec.size(); + for (size_t i = 0; i < size; ++i) { + vec[i] += that[i]; + } + return *this; + } + Vector operator*(Type const& that) const { + return BinaryScalar(that, std::multiplies()); + } + + bool operator==(Vector const& that) const { + if (Size() != that.Size()) { + return false; + } + for (size_t i = 0; i < that.Size(); ++i) { + if (vec[i] != that.vec[i]) { + return false; + } + } + return true; + } + + Type const& operator[](size_t i) const { + return vec[i]; + } + Type& operator[](size_t i) { + return vec[i]; + } + + friend std::ostream& operator<<(std::ostream& os, Vector vec) { + for (size_t i = 0; i < vec.Size(); ++i) { + os << vec[i]; + if (i != vec.Size() - 1) { + os << ", "; + } + } + return os; + } +}; + +/*\brief A specialization over Vector for scalar value. This can be twice + * faster. */ +template +struct ScalarContainer { + using ValueT = Type; + static_assert(std::is_floating_point::value, ""); + ValueT vec; + + public: + constexpr ScalarContainer() = default; + constexpr explicit ScalarContainer(size_t, ValueT v) : vec{v} {} + void Resize(size_t) const { } + constexpr ScalarContainer(ValueT v) : vec{v} {} // NOLINT + ScalarContainer& operator+=(ValueT v) { + vec += v; + return *this; + } + ScalarContainer& operator-=(ValueT v) { + vec -= v; + return *this; + } + constexpr ScalarContainer operator+(ScalarContainer v) const { + return ScalarContainer{vec + v.vec}; + } + constexpr ScalarContainer operator*(ScalarContainer v) const { + return ScalarContainer{vec * v.vec}; + } + + ScalarContainer &operator+=(ScalarContainer v) { + vec += v.vec; + return *this; + } + constexpr ValueT const& operator[](size_t) const { return vec; } + ValueT& operator[](size_t) { return vec; } + + constexpr size_t Size() const { return 1; } + + friend std::ostream& operator<<(std::ostream& os, ScalarContainer const& s) { + os << s.vec; + return os; + } +}; + +using Scalar = ScalarContainer; + +static_assert(std::is_pod::value, ""); +static_assert(std::alignment_of::value == + std::alignment_of::value, + ""); +static_assert(sizeof(Scalar) == sizeof(double), ""); + +using MultiGradientPair = detail::GradientPairInternal>; +using SingleGradientPair = detail::GradientPairInternal; + +template +bool AnyLT(T const& vec, float v) { + // A very time consuming loop. 
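  // This check is the vector analogue of the scalar `sum_hess < min_child_weight`
  // test: a candidate is rejected as soon as any single target falls below the
  // threshold. Rough usage with made-up numbers (illustrative only):
  //
  //   Vector<double> hess(3, 2.0);
  //   hess[1] = 0.1;
  //   bool reject = AnyLT(hess, 1.0f);  // true, because target 1 is below 1.0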
+ for (size_t i = 0; i < vec.Size(); ++i) { + if (vec[i] < v) { + return true; + } + } + return false; +} + +template <> +inline bool AnyLT(Scalar const& value, float v) { + return value.vec < v; +} + +template +bool AnyLE(T const& lhs, T const& rhs) { + CHECK_EQ(lhs.Size(), rhs.Size()); + for (size_t i = 0; i < lhs.Size(); ++i) { + if (lhs[i] <= rhs[i]) { + return true; + } + } + return false; +} + +namespace tree { +template struct WeightType { + using Type = typename std::conditional< + std::is_same::value, Vector, + typename std::conditional< + std::is_same::value, + Vector, ScalarContainer>::type + >::type; +}; +template +using WeightT = typename WeightType::Type; + +template > +ReturnT MultiCalcWeight(GradientT const &sum_grad, GradientT const &sum_hess, + TrainParam const &p) { + ReturnT w(sum_grad.Size(), 0); + for (size_t i = 0; i < w.Size(); ++i) { + w[i] = CalcWeight(p, sum_grad[i], sum_hess[i]); + } + return w; +} + +template <> +inline ScalarContainer MultiCalcWeight>( + Scalar const &sum_grad, Scalar const &sum_hess, TrainParam const &p) { + return CalcWeight(p, sum_grad.vec, sum_hess.vec); +} + +template > +ReturnT MultiCalcWeight(GradientT const &sum_gradients, TrainParam const &p) { + return MultiCalcWeight(sum_gradients.GetGrad(), sum_gradients.GetHess(), p); +} + +template +float MultiCalcGainGivenWeight(GradientT const &sum_grad, + GradientT const &sum_hess, + Weight const& weight, + TrainParam const &p) { + float gain { 0 }; + for (size_t i = 0; i < weight.Size(); ++i) { + gain += -weight[i] * ThresholdL1(sum_grad[i], p.reg_alpha); + } + return gain; +} + +template <> +inline float MultiCalcGainGivenWeight>( + Scalar const &sum_grad, Scalar const &sum_hess, + ScalarContainer const &weight, TrainParam const &p) { + return CalcGainGivenWeight(p, sum_grad.vec, sum_hess.vec, + static_cast(weight.vec)); +} + +template +inline GradientT MakeGradientPair(size_t columns); +template <> +inline SingleGradientPair MakeGradientPair(size_t) { + return SingleGradientPair{}; +} +template <> +inline MultiGradientPair MakeGradientPair(size_t columns) { + return MultiGradientPair(Vector(columns, 0), + Vector(columns, 0)); +} + +template +struct ExactSplitEntryContainer { + using Candidate = SplitEntryContainer; + Candidate candidate; + float last_value = {std::numeric_limits::quiet_NaN()}; + + bst_node_t nidx { 0 }; + GradientT parent_sum; + float root_gain { 0 }; + + ExactSplitEntryContainer() = default; + explicit ExactSplitEntryContainer(bst_node_t nidx, + GradientT const &gradient_sum, + float root_gain, + TrainParam const &p) + : nidx{nidx}, parent_sum{gradient_sum}, root_gain{root_gain} { + } + + bool IsValid(int32_t depth, int32_t leaves, TrainParam const& p) { + if (candidate.loss_chg <= kRtEps) { return false; } + if (AnyLT(candidate.left_sum.GetHess(), p.min_child_weight)) { + return false; + } + if (candidate.loss_chg < p.min_split_loss) { + return false; + } + if (p.max_depth > 0 && depth >= p.max_depth) { return false; } + if (p.max_leaves > 0 && leaves >= p.max_leaves) { return false; } + return true; + } + + void Accumulate(GradientT const& g, float value) { + this->candidate.left_sum += g; + this->last_value = value; + } + + static bool ChildIsValid(int32_t depth, int32_t leaves, TrainParam const& p) { + if (p.max_depth > 0 && depth >= p.max_depth) { + return false; + } + if (p.max_leaves > 0 && leaves >= p.max_leaves) { + return false; + } + return true; + } +}; + +template +class MultiValueConstraint { + common::Span lower_; + common::Span upper_; + common::Span 
monotone_; + size_t targets_; + + public: + struct Storage { + HostDeviceVector lower_storage; + HostDeviceVector upper_storage; + }; + + explicit MultiValueConstraint(size_t targets) : targets_{targets} {} + void Init(TrainParam const& p, Storage* storage) { + monotone_ = {p.monotone_constraints}; + if (!monotone_.empty()) { + storage->lower_storage.Resize(p.MaxNodes(), + -std::numeric_limits::max()); + storage->upper_storage.Resize(p.MaxNodes(), + std::numeric_limits::max()); + lower_ = storage->lower_storage.HostSpan(); + upper_ = storage->upper_storage.HostSpan(); + } + } + + template + Type CalcWeight(GradientT const &sum_gradients, bst_node_t nidx, + TrainParam const &p) const { + auto weight = + MultiCalcWeight(sum_gradients.GetGrad(), sum_gradients.GetHess(), p); + static_assert(std::is_same::value, + ""); + if (monotone_.empty()) { + return weight; + } + CHECK_GT(targets_, 0); + auto lower = lower_.subspan(nidx * targets_, targets_); + auto upper = upper_.subspan(nidx * targets_, targets_); + for (size_t i = 0; i < weight.Size(); ++i) { + auto &w = weight[i]; + if (w < lower[i]) { + w = lower[i]; + } else if (w > upper[i]) { + w = upper[i]; + } + } + return weight; + } + + template + float CalcSplitGain(GradientT const &left, GradientT const &right, + bst_node_t nidx, bst_feature_t feature_id, + TrainParam const &p) const { + auto left_weight = this->CalcWeight(left, nidx, p); + auto right_weight = this->CalcWeight(right, nidx, p); + typename Type::ValueT gain = + MultiCalcGainGivenWeight(left.GetGrad(), left.GetHess(), left_weight, + p) + + MultiCalcGainGivenWeight(right.GetGrad(), right.GetHess(), right_weight, + p); + if (monotone_.empty()) { + return gain; + } + + int32_t constraint = + feature_id >= monotone_.size() ? 0 : monotone_[feature_id]; + if (constraint > 0) { + if (!AnyLE(left_weight, right_weight)) { + gain = -std::numeric_limits::infinity(); + } + } else if (constraint < 0) { + if (!AnyLE(right_weight, left_weight)) { + gain = -std::numeric_limits::infinity(); + } + } + return gain; + } + + void Split(bst_node_t nidx, + bst_node_t left, Type const &left_weight, + bst_node_t right, Type const &right_weight, + bst_feature_t feature_id) { + if (monotone_.empty()) { + return; + } + + Type mid {left_weight.Size(), 0.0}; + for (size_t i = 0; i < mid.Size(); ++i) { + mid[i] = (left_weight[i] + right_weight[i]) / 2; + CHECK(!std::isnan(mid[i])); + } + int32_t constraint = monotone_[feature_id]; + + auto upper_left = upper_.subspan(left * targets_, targets_); + auto lower_left = lower_.subspan(left * targets_, targets_); + + auto upper_right = upper_.subspan(right * targets_, targets_); + auto lower_right = lower_.subspan(right * targets_, targets_); + + auto upper_parent = upper_.subspan(nidx * targets_, targets_); + auto lower_parent = lower_.subspan(nidx * targets_, targets_); + + for (size_t i = 0; i < mid.Size(); ++i) { + upper_left[i] = upper_parent[i]; + upper_right[i] = upper_parent[i]; + + lower_left[i] = lower_parent[i]; + lower_right[i] = lower_parent[i]; + } + + if (constraint < 0) { + for (size_t i = 0; i < mid.Size(); ++i) { + lower_left[i] = mid[i]; + upper_right[i] = mid[i]; + } + } else if (constraint > 0) { + for (size_t i = 0; i < mid.Size(); ++i) { + upper_left[i] = mid[i]; + lower_right[i] = mid[i]; + } + } + } +}; + +template +class MultiExact : public TreeUpdater { + protected: + using SplitEntry = ExactSplitEntryContainer; + // A copy of gradients. + std::vector gpairs_; + // Maps row idx to tree node idx. 
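+  // Sized in InitRoot and rewritten by ApplySplit as rows are partitioned into child nodes; rows missing the split feature follow the split's default direction.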
+ std::vector positions_; + // When a node can be further splited. 0 means no, 1 means yes. + // std::vector is not thread safe. + std::vector is_splitable_; + // Splits for current tree. + std::vector nodes_split_; + // Pointer to current layer of nodes. + size_t node_shift_; + // Scan of gradient statistic, used in enumeration. + std::vector> tloc_scans_; + common::ColumnSampler sampler_; + LearnerModelParam const* mparam_; + + bool NeedForward(SparsePage::Inst const &column, TrainParam const& p) const { + return p.default_direction == 2 || + ((column.size() != gpairs_.size()) && // with missing + !(column.size() != 0 && + column[0].fvalue == column[column.size() - 1].fvalue)); + } + bool NeedBackward(SparsePage::Inst const &column, TrainParam const& p) const { + return p.default_direction != 2; + } + + public: + explicit MultiExact(GenericParameter const *runtime, + LearnerModelParam const *mparam) + : mparam_{mparam}, value_constraints_{mparam->num_targets} { + CHECK_NE(mparam->num_targets, 0); + if (runtime) { + tparam_ = runtime; + } + node_shift_ = 0; + monitor_.Init(__func__); + } + void Configure(const Args& args) override { + param_.UpdateAllowUnknown(args); + if (param_.grow_policy != TrainParam::kDepthWise) { + LOG(WARNING) << "Exact tree method supports only depth wise grow policy."; + } + } + void LoadConfig(Json const& in) override { + auto const& config = get(in); + FromJson(config.at("train_param"), &this->param_); + } + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["train_param"] = ToJson(param_); + } + + void InitData(DMatrix* data, common::Span gpairs); + void InitRoot(DMatrix* data, RegTree* tree); + + void EvaluateFeature(bst_feature_t fid, SparsePage::Inst const &column, + std::vector* p_scans, + std::vector *p_nodes) const; + void EvaluateSplit(DMatrix *data, common::Span features); + + size_t ExpandTree(RegTree* p_tree, std::vector* next); + void ApplySplit(DMatrix* m, RegTree* p_tree); + + void UpdateTree(HostDeviceVector* gpair, + DMatrix* data, RegTree* tree); + + public: + void Update(HostDeviceVector* gpair, + DMatrix* data, + const std::vector& trees) override { + interaction_constraints_.Configure(param_, data->Info().num_row_); + float lr = param_.learning_rate; + param_.learning_rate = lr / trees.size(); + for (auto p_tree : trees) { + this->UpdateTree(gpair, data, p_tree); + } + param_.learning_rate = lr; + } + char const* Name() const override { return "grow_colmaker"; }; + + private: + common::Monitor monitor_; + FeatureInteractionConstraintHost interaction_constraints_; + MultiValueConstraint> value_constraints_; + typename MultiValueConstraint>::Storage + monotone_constriants_; + TrainParam param_; +}; +} // namespace tree +} // namespace xgboost +#endif // XGBOOST_TREE_UPDATER_EXACT_H_ diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 25d2645e1032..3ab436f6f264 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -890,7 +890,9 @@ class GPUHistMaker : public TreeUpdater { #if !defined(GTEST_TEST) XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist") .describe("Grow tree with GPU.") - .set_body([]() { return new GPUHistMaker(); }); + .set_body([](GenericParameter const* tparam, LearnerModelParam const* mparam) { + return new GPUHistMaker(); + }); #endif // !defined(GTEST_TEST) } // namespace tree diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index c4fdbe3c0308..1dbac5f7a211 100644 --- a/src/tree/updater_histmaker.cc +++ 
b/src/tree/updater_histmaker.cc @@ -743,14 +743,14 @@ class GlobalProposalHistMaker: public CQHistMaker { XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker") .describe("Tree constructor that uses approximate histogram construction.") -.set_body([]() { +.set_body([](GenericParameter const* tparam, LearnerModelParam const* mparam) { return new CQHistMaker(); }); // The updater for approx tree method. XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker") .describe("Tree constructor that uses approximate global of histogram construction.") -.set_body([]() { +.set_body([](GenericParameter const* tparam, LearnerModelParam const* mparam) { return new GlobalProposalHistMaker(); }); } // namespace tree diff --git a/src/tree/updater_prune.cc b/src/tree/updater_prune.cc index 76a8916a0598..2c0a1d07ec18 100644 --- a/src/tree/updater_prune.cc +++ b/src/tree/updater_prune.cc @@ -24,7 +24,7 @@ DMLC_REGISTRY_FILE_TAG(updater_prune); class TreePruner: public TreeUpdater { public: TreePruner() { - syncher_.reset(TreeUpdater::Create("sync", tparam_)); + syncher_.reset(TreeUpdater::Create("sync", tparam_, mparam_)); pruner_monitor_.Init("TreePruner"); } char const* Name() const override { @@ -98,22 +98,23 @@ class TreePruner: public TreeUpdater { npruned = this->TryPruneLeaf(tree, nid, tree.GetDepth(nid), npruned); } } - LOG(INFO) << "tree pruning end, " - << tree.NumExtraNodes() << " extra nodes, " << npruned - << " pruned nodes, max_depth=" << tree.MaxDepth(); + // LOG(INFO) << "tree pruning end, " + // << tree.NumExtraNodes() << " extra nodes, " << npruned + // << " pruned nodes, max_depth=" << tree.MaxDepth(); } private: // synchronizer std::unique_ptr syncher_; // training parameter + LearnerModelParam const* mparam_; TrainParam param_; common::Monitor pruner_monitor_; }; XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune") .describe("Pruner that prune the tree according to statistics.") -.set_body([]() { +.set_body([](GenericParameter const* tparam, LearnerModelParam const* mparam) { return new TreePruner(); }); } // namespace tree diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 30eb01a726ee..7055504c726e 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -40,7 +40,7 @@ DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam); void QuantileHistMaker::Configure(const Args& args) { // initialize pruner if (!pruner_) { - pruner_.reset(TreeUpdater::Create("prune", tparam_)); + pruner_.reset(TreeUpdater::Create("prune", tparam_, mparam_)); } pruner_->Configure(args); param_.UpdateAllowUnknown(args); @@ -663,7 +663,7 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( } CHECK((*p_last_tree_)[nid].IsLeaf()); } - leaf_value = (*p_last_tree_)[nid].LeafValue(); + leaf_value = p_last_tree_->LeafValue(nid); for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) { out_preds[*it] += leaf_value; @@ -1360,17 +1360,17 @@ XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker") .describe("(Deprecated, use grow_quantile_histmaker instead.)" " Grow tree using quantized histogram.") .set_body( - []() { + [](GenericParameter const* tparam, LearnerModelParam const* mparam) { LOG(WARNING) << "grow_fast_histmaker is deprecated, " << "use grow_quantile_histmaker instead."; - return new QuantileHistMaker(); + return new QuantileHistMaker(mparam); }); XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker") .describe("Grow tree using quantized histogram.") .set_body( - []() 
{ - return new QuantileHistMaker(); + [](GenericParameter const* tparam, LearnerModelParam const* mparam) { + return new QuantileHistMaker(mparam); }); } // namespace tree diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index d74c16f72ceb..29937ccba813 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -110,7 +110,7 @@ struct CPUHistMakerTrainParam /*! \brief construct a tree using quantized feature values */ class QuantileHistMaker: public TreeUpdater { public: - QuantileHistMaker() { + explicit QuantileHistMaker(LearnerModelParam const* mparam) : mparam_{mparam} { updater_monitor_.Init("QuantileHistMaker"); } void Configure(const Args& args) override; @@ -155,6 +155,7 @@ class QuantileHistMaker: public TreeUpdater { CPUHistMakerTrainParam hist_maker_param_; // training parameter TrainParam param_; + LearnerModelParam const* mparam_; // quantized data matrix GHistIndexMatrix gmat_; // (optional) data matrix with feature grouping diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index d63d88c802d9..5a42b6c1ff04 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -153,7 +153,7 @@ class TreeRefresher: public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh") .describe("Refresher that refreshes the weight and statistics according to data.") -.set_body([]() { +.set_body([](GenericParameter const* tparam, LearnerModelParam const* mparam) { return new TreeRefresher(); }); } // namespace tree diff --git a/src/tree/updater_skmaker.cc b/src/tree/updater_skmaker.cc index 69cb4e58bbed..50914bc832e7 100644 --- a/src/tree/updater_skmaker.cc +++ b/src/tree/updater_skmaker.cc @@ -390,7 +390,7 @@ class SketchMaker: public BaseMaker { XGBOOST_REGISTER_TREE_UPDATER(SketchMaker, "grow_skmaker") .describe("Approximate sketching maker.") -.set_body([]() { +.set_body([](GenericParameter const* tparam, LearnerModelParam const* mparam) { return new SketchMaker(); }); } // namespace tree diff --git a/src/tree/updater_sync.cc b/src/tree/updater_sync.cc index 578bfb83cea9..2842958c0375 100644 --- a/src/tree/updater_sync.cc +++ b/src/tree/updater_sync.cc @@ -53,7 +53,7 @@ class TreeSyncher: public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync") .describe("Syncher that synchronize the tree in all distributed nodes.") -.set_body([]() { +.set_body([](GenericParameter const* tparam, LearnerModelParam const* mparam) { return new TreeSyncher(); }); } // namespace tree diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index 55edb324fee1..e6132f619131 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -1,4 +1,5 @@ #pragma once +#include #include #include #include @@ -148,7 +149,6 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx, const std::vector& sorted_column, const std::vector& sorted_weights, size_t num_bins) { - // Check the endpoints are correct CHECK_GT(sorted_column.size(), 0); EXPECT_LT(cuts.MinValues().at(column_idx), sorted_column.front()); diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index 74002b75aacc..469405724daf 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -39,6 +39,26 @@ TEST(MetaInfo, GetSet) { ASSERT_EQ(info.group_ptr_.size(), 0); } +TEST(MetaInfo, SetLabels) { + size_t constexpr kRows { 128 }; + size_t constexpr kCols { 128 }; + xgboost::HostDeviceVector labels; + std::string arr = + 
xgboost::RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface( + &labels, true); + xgboost::MetaInfo info; + info.SetInfo("label", arr, xgboost::GenericParameter::kCpuId); + ASSERT_EQ(info.labels_.Size(), labels.Size()); + ASSERT_EQ(info.labels_cols, kCols); + ASSERT_EQ(info.labels_rows, kRows); + for (size_t i = 0; i < kRows; ++i) { + for (size_t j = 0; j < kCols; ++j) { + ASSERT_EQ(labels.HostVector()[i * kCols + j], + info.labels_.HostVector()[i * kCols + j]); + } + } +} + TEST(MetaInfo, SaveLoadBinary) { xgboost::MetaInfo info; uint64_t constexpr kRows { 64 }, kCols { 32 }; diff --git a/tests/cpp/data/test_metainfo.cu b/tests/cpp/data/test_metainfo.cu index ca688dcab3ec..2e605067e997 100644 --- a/tests/cpp/data/test_metainfo.cu +++ b/tests/cpp/data/test_metainfo.cu @@ -45,20 +45,20 @@ TEST(MetaInfo, FromInterface) { std::string str = PrepareData(" expected_group_ptr = {0, 4, 7, 9, 10}; EXPECT_EQ(info.group_ptr_, expected_group_ptr); } @@ -81,7 +81,7 @@ TEST(MetaInfo, Group) { thrust::device_vector d_uint; std::string uint_str = PrepareData(" d_int64; std::string int_str = PrepareData(" d_float; std::string float_str = PrepareData(" d_data; std::string str = PrepareData(" const& obj, std::vector labels, std::vector weights, std::vector out_grad, - std::vector out_hess) { + std::vector out_hess, + size_t rows) { xgboost::MetaInfo info; - info.num_row_ = labels.size(); + if (rows != 0) { + info.num_row_ = rows; + } else { + info.num_row_ = labels.size(); + } info.labels_.HostVector() = labels; info.weights_.HostVector() = weights; @@ -201,11 +206,18 @@ Json RandomDataGenerator::ArrayInterfaceImpl(HostDeviceVector *storage, } std::string RandomDataGenerator::GenerateArrayInterface( - HostDeviceVector *storage) const { + HostDeviceVector *storage, bool wrap_in_column) const { auto array_interface = this->ArrayInterfaceImpl(storage, rows_, cols_); - std::string out; - Json::Dump(array_interface, &out); - return out; + if (wrap_in_column) { + Json out{Array(std::vector{array_interface})}; + std::string str; + Json::Dump(out, &str); + return str; + } else { + std::string str; + Json::Dump(array_interface, &str); + return str; + } } std::pair, std::string> diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 7d59077184ae..1e88113954b5 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -49,12 +49,19 @@ void CreateSimpleTestData(const std::string& filename); void CreateBigTestData(const std::string& filename, size_t n_entries); + +/* + * \brief Automated checking for objective function. + * + * Specify number rows if it doesn't equal to number of labels. + */ void CheckObjFunction(std::unique_ptr const& obj, std::vector preds, std::vector labels, std::vector weights, std::vector out_grad, - std::vector out_hess); + std::vector out_hess, + size_t rows=0); xgboost::Json CheckConfigReloadImpl(xgboost::Configurable* const configurable, std::string name); @@ -219,7 +226,8 @@ class RandomDataGenerator { void GenerateDense(HostDeviceVector* out) const; - std::string GenerateArrayInterface(HostDeviceVector* storage) const; + std::string GenerateArrayInterface(HostDeviceVector *storage, + bool wrap_in_column = false) const; /*! * \brief Generate batches of array interface stored in consecutive memory. 
diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index b0081448545c..4a2cfaf886b4 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include #include diff --git a/tests/cpp/tree/test_exact.cc b/tests/cpp/tree/test_exact.cc new file mode 100644 index 000000000000..cee1ce1befe0 --- /dev/null +++ b/tests/cpp/tree/test_exact.cc @@ -0,0 +1,196 @@ +/*! + * Copyright 2020 by XGBoost Contributors + */ +#include + +#include +#include + +#include "../helpers.h" +#include "../../../src/tree/updater_exact.h" + +namespace xgboost { +namespace tree { + +class MultiExactTest : public :: testing::Test { + protected: + static bst_row_t constexpr kRows { 64 }; + static bst_feature_t constexpr kCols { 16 }; + static bst_feature_t constexpr kLabels{16}; + + HostDeviceVector gradients_; + std::shared_ptr p_dmat_ {nullptr}; + + void SetUp() override { + gradients_ = GenerateRandomGradients(kRows * kLabels, -1.0f, 1.0f); + auto h_grad = common::Span{gradients_.HostVector()}; + p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatrix(true); + p_dmat_->Info().labels_.Resize(kRows); + + auto &h_labels = p_dmat_->Info().labels_.HostVector(); + h_labels.resize(kRows * kLabels); + SimpleLCG gen; + xgboost::SimpleRealUniformDistribution dist(0, 1); + + for (auto &v : h_labels) { + v = dist(&gen); + } + p_dmat_->Info().labels_cols = kCols; + p_dmat_->Info().labels_rows = kRows; + } + + ~MultiExactTest() override = default; +}; + + +class MultiExactUpdaterForTest : public MultiExact { + public: + explicit MultiExactUpdaterForTest(GenericParameter const *runtime, + LearnerModelParam const& mparam) + : MultiExact{runtime, &mparam} { + this->Configure(Args{}); + } + decltype(gpairs_) &GetGpairs() { return gpairs_; } + decltype(positions_) &GetPositions() { return positions_; } + decltype(nodes_split_) & GetNodesSplit() { return nodes_split_; } +}; + + +TEST_F(MultiExactTest, InitData) { + GenericParameter runtime; + runtime.InitAllowUnknown(Args{}); + runtime.gpu_id = GenericParameter::kCpuId; + LearnerModelParam mparam; + mparam.output_type = OutputType::kMulti; + mparam.num_targets = kLabels; + MultiExactUpdaterForTest updater(&runtime, mparam); + auto h_grad = common::Span{gradients_.HostVector()}; + updater.InitData(p_dmat_.get(), h_grad); + + auto const& gpairs = updater.GetGpairs(); + + ASSERT_EQ(gpairs.size(), p_dmat_->Info().num_row_); + for (size_t i = 0; i < gpairs.size(); ++i) { + auto const& pair = gpairs[i]; + auto const& grad = pair.GetGrad(); + auto const& hess = pair.GetHess(); + ASSERT_EQ(pair.GetGrad().Size(), p_dmat_->Info().labels_cols); + + for (size_t j = 0; j < grad.Size(); ++j) { + ASSERT_EQ(grad[j], h_grad[i * p_dmat_->Info().labels_cols + j].GetGrad()); + ASSERT_EQ(hess[j], h_grad[i * p_dmat_->Info().labels_cols + j].GetHess()); + } + } + + ASSERT_TRUE(updater.GetPositions().empty()); + ASSERT_TRUE(updater.GetNodesSplit().empty()); +} + +TEST_F(MultiExactTest, InitRoot) { + RegTree tree(p_dmat_->Info().num_col_, RegTree::kMulti); + GenericParameter runtime; + runtime.InitAllowUnknown(Args{}); + runtime.gpu_id = GenericParameter::kCpuId; + LearnerModelParam mparam; + mparam.output_type = OutputType::kMulti; + mparam.num_targets = kLabels; + MultiExactUpdaterForTest updater{&runtime, mparam}; + updater.Configure(Args{}); + auto h_grad = common::Span{gradients_.HostVector()}; + updater.InitData(p_dmat_.get(), h_grad); + + updater.InitRoot(p_dmat_.get(), 
&tree); + auto root_weight = tree.VectorLeafValue(RegTree::kRoot); + ASSERT_EQ(root_weight.size(), p_dmat_->Info().labels_cols); + ASSERT_EQ(updater.GetPositions().size(), p_dmat_->Info().num_row_); + ASSERT_EQ(updater.GetNodesSplit().front().nidx, RegTree::kRoot); +} + +TEST_F(MultiExactTest, EvaluateSplit) { + RegTree tree(p_dmat_->Info().num_col_, RegTree::kMulti); + GenericParameter runtime; + runtime.InitAllowUnknown(Args{}); + runtime.gpu_id = GenericParameter::kCpuId; + LearnerModelParam mparam; + mparam.output_type = OutputType::kMulti; + mparam.num_targets = kLabels; + MultiExactUpdaterForTest updater{&runtime, mparam}; + + for (auto& page : p_dmat_->GetBatches()) { + auto& offset = page.offset.HostVector(); + auto& data = page.data.HostVector(); + // No need for forward search for 0^th column. + data[offset[0]] = data[offset[1] - 1]; + } + + updater.Configure(Args{}); + auto h_grad = common::Span{gradients_.HostVector()}; + updater.InitData(p_dmat_.get(), h_grad); + updater.InitRoot(p_dmat_.get(), &tree); + updater.GetNodesSplit().front().candidate.loss_chg = 0.001; + + std::vector features { 0 }; + updater.EvaluateSplit(p_dmat_.get(), features); + + ASSERT_FALSE(updater.GetNodesSplit().front().candidate.DefaultLeft()); + ASSERT_EQ(updater.GetNodesSplit().front().candidate.SplitIndex(), 0); +} + +TEST_F(MultiExactTest, ApplySplit) { + RegTree tree(p_dmat_->Info().num_col_, RegTree::kMulti); + GenericParameter runtime; + runtime.InitAllowUnknown(Args{}); + runtime.gpu_id = GenericParameter::kCpuId; + LearnerModelParam mparam; + mparam.output_type = OutputType::kMulti; + mparam.num_targets = kLabels; + MultiExactUpdaterForTest updater{&runtime, mparam}; + updater.Configure(Args{}); + auto h_grad = common::Span{gradients_.HostVector()}; + updater.InitData(p_dmat_.get(), h_grad); + updater.InitRoot(p_dmat_.get(), &tree); + ASSERT_EQ(updater.GetNodesSplit().size(), 1); + + // Invent a valid split entry. 
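+  // Hessian sums of 1.0 and a loss_chg of 2.0 clear min_child_weight, kRtEps and min_split_loss (cf. ExactSplitEntryContainer::IsValid), so ApplySplit can expand the root into two children below.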
+ auto left_sum = + MakeGradientPair(p_dmat_->Info().labels_cols); + auto right_sum = + MakeGradientPair(p_dmat_->Info().labels_cols); + for (size_t i = 0; i < p_dmat_->Info().labels_cols; ++i) { + left_sum.GetGrad()[i] = 1.3f; + left_sum.GetHess()[i] = 1.0f; + } + float split_value = 0.6; + bst_feature_t split_ind = 0; + updater.GetNodesSplit().front().candidate.loss_chg = 1.0f; + auto success = updater.GetNodesSplit().front().candidate.Update( + 2.0f, split_ind, split_value, true, left_sum, right_sum); + ASSERT_TRUE(success); + ASSERT_TRUE(updater.GetNodesSplit().front().candidate.DefaultLeft()); + updater.ApplySplit(p_dmat_.get(), &tree); + ASSERT_EQ(tree.NumExtraNodes(), 2); + + auto const& pos = updater.GetPositions(); + ASSERT_EQ(pos.size(), p_dmat_->Info().num_row_); + std::set non_missing; + for (auto const& page : p_dmat_->GetBatches()) { + auto column = page[split_ind]; + for (auto const& e : column) { + if (e.fvalue < split_value) { + ASSERT_EQ(pos[e.index], 1); + } else { + ASSERT_EQ(pos[e.index], 2); + } + non_missing.insert(e.index); + } + + for (bst_row_t i = 0; i < pos.size(); ++i) { + if (non_missing.find(i) == non_missing.cend()) { + // dft left + ASSERT_EQ(pos[i], 1); + } + } + } +} +} // namespace tree +} // namespace xgboost diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index fd5c9f43fb2a..aba2e9023295 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -489,7 +489,9 @@ TEST(GpuHist, ExternalMemoryWithSampling) { TEST(GpuHist, ConfigIO) { GenericParameter generic_param(CreateEmptyGenericParam(0)); - std::unique_ptr updater {TreeUpdater::Create("grow_gpu_hist", &generic_param) }; + LearnerModelParam mparam; + std::unique_ptr updater {TreeUpdater::Create("grow_gpu_hist", + &generic_param, &mparam) }; updater->Configure(Args{}); Json j_updater { Object() }; diff --git a/tests/cpp/tree/test_histmaker.cc b/tests/cpp/tree/test_histmaker.cc index e1cb3568d5ef..248f4a25f3bf 100644 --- a/tests/cpp/tree/test_histmaker.cc +++ b/tests/cpp/tree/test_histmaker.cc @@ -33,8 +33,8 @@ TEST(GrowHistMaker, InteractionConstraint) { // With constraints RegTree tree; tree.param.num_feature = kCols; - - std::unique_ptr updater { TreeUpdater::Create("grow_histmaker", ¶m) }; + LearnerModelParam mparam; + std::unique_ptr updater { TreeUpdater::Create("grow_histmaker", ¶m, &mparam) }; updater->Configure(Args{ {"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}}); @@ -50,8 +50,8 @@ TEST(GrowHistMaker, InteractionConstraint) { // Without constraints RegTree tree; tree.param.num_feature = kCols; - - std::unique_ptr updater { TreeUpdater::Create("grow_histmaker", ¶m) }; + LearnerModelParam mparam; + std::unique_ptr updater { TreeUpdater::Create("grow_histmaker", ¶m, &mparam) }; updater->Configure(Args{{"num_feature", std::to_string(kCols)}}); updater->Update(&gradients, p_dmat.get(), {&tree}); diff --git a/tests/cpp/tree/test_prune.cc b/tests/cpp/tree/test_prune.cc index dbe910a8f183..20b91baa2db2 100644 --- a/tests/cpp/tree/test_prune.cc +++ b/tests/cpp/tree/test_prune.cc @@ -38,7 +38,8 @@ TEST(Updater, Prune) { tree.param.UpdateAllowUnknown(cfg); std::vector trees {&tree}; // prepare pruner - std::unique_ptr pruner(TreeUpdater::Create("prune", &lparam)); + LearnerModelParam mparam; + std::unique_ptr pruner(TreeUpdater::Create("prune", &lparam, &mparam)); pruner->Configure(cfg); // loss_chg < min_split_loss; diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index 
1b6ab89e9992..140e3950b928 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -466,47 +466,46 @@ class QuantileHistMock : public QuantileHistMaker { int static constexpr kNRows = 8, kNCols = 16; std::shared_ptr dmat_; + LearnerModelParam mparam_; const std::vector > cfg_; std::shared_ptr > float_builder_; std::shared_ptr > double_builder_; public: explicit QuantileHistMock( - const std::vector >& args, - const bool single_precision_histogram = false, bool batch = true) : - cfg_{args} { + const std::vector> &args, + const bool single_precision_histogram = false, bool batch = true) + : QuantileHistMaker{&mparam_}, cfg_{args} { QuantileHistMaker::Configure(args); spliteval_->Init(¶m_); dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix(); if (single_precision_histogram) { - float_builder_.reset( - new BuilderMock( - param_, - std::move(pruner_), - std::unique_ptr(spliteval_->GetHostClone()), - int_constraint_, - dmat_.get())); + float_builder_.reset(new BuilderMock( + param_, std::move(pruner_), + std::unique_ptr(spliteval_->GetHostClone()), + int_constraint_, dmat_.get())); if (batch) { float_builder_->SetHistSynchronizer(new BatchHistSynchronizer()); float_builder_->SetHistRowsAdder(new BatchHistRowsAdder()); } else { - float_builder_->SetHistSynchronizer(new DistributedHistSynchronizer()); + float_builder_->SetHistSynchronizer( + new DistributedHistSynchronizer()); float_builder_->SetHistRowsAdder(new DistributedHistRowsAdder()); } } else { - double_builder_.reset( - new BuilderMock( - param_, - std::move(pruner_), - std::unique_ptr(spliteval_->GetHostClone()), - int_constraint_, - dmat_.get())); + double_builder_.reset(new BuilderMock( + param_, std::move(pruner_), + std::unique_ptr(spliteval_->GetHostClone()), + int_constraint_, dmat_.get())); if (batch) { - double_builder_->SetHistSynchronizer(new BatchHistSynchronizer()); + double_builder_->SetHistSynchronizer( + new BatchHistSynchronizer()); double_builder_->SetHistRowsAdder(new BatchHistRowsAdder()); } else { - double_builder_->SetHistSynchronizer(new DistributedHistSynchronizer()); - double_builder_->SetHistRowsAdder(new DistributedHistRowsAdder()); + double_builder_->SetHistSynchronizer( + new DistributedHistSynchronizer()); + double_builder_->SetHistRowsAdder( + new DistributedHistRowsAdder()); } } } diff --git a/tests/cpp/tree/test_refresh.cc b/tests/cpp/tree/test_refresh.cc index 3689940fda35..927a2bd72491 100644 --- a/tests/cpp/tree/test_refresh.cc +++ b/tests/cpp/tree/test_refresh.cc @@ -32,22 +32,23 @@ TEST(Updater, Refresh) { auto lparam = CreateEmptyGenericParam(GPUIDX); tree.param.UpdateAllowUnknown(cfg); std::vector trees {&tree}; - std::unique_ptr refresher(TreeUpdater::Create("refresh", &lparam)); + LearnerModelParam mparam; + std::unique_ptr refresher(TreeUpdater::Create("refresh", &lparam, &mparam)); tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f, /*left_sum=*/0.0f, /*right_sum=*/0.0f); int cleft = tree[0].LeftChild(); int cright = tree[0].RightChild(); - tree.Stat(cleft).base_weight = 1.2; - tree.Stat(cright).base_weight = 1.3; + tree.Stat(cleft).base_weight = 1.2f; + tree.Stat(cright).base_weight = 1.3f; refresher->Configure(cfg); refresher->Update(&gpair, p_dmat.get(), trees); bst_float constexpr kEps = 1e-6; - ASSERT_NEAR(-0.183392, tree[cright].LeafValue(), kEps); - ASSERT_NEAR(-0.224489, tree.Stat(0).loss_chg, kEps); + ASSERT_NEAR(-0.183392f, tree.LeafValue(cright), kEps); + ASSERT_NEAR(-0.224489f, tree.Stat(0).loss_chg, kEps); 
ASSERT_NEAR(0, tree.Stat(cleft).loss_chg, kEps); ASSERT_NEAR(0, tree.Stat(1).loss_chg, kEps); ASSERT_NEAR(0, tree.Stat(2).loss_chg, kEps); diff --git a/tests/cpp/tree/test_tree_model.cc b/tests/cpp/tree/test_tree_model.cc index dbf2b80a2de4..437ed66382ad 100644 --- a/tests/cpp/tree/test_tree_model.cc +++ b/tests/cpp/tree/test_tree_model.cc @@ -82,7 +82,7 @@ TEST(Tree, Load) { EXPECT_EQ(tree.GetDepth(1), 1); EXPECT_EQ(tree[0].SplitCond(), 0.5f); EXPECT_EQ(tree[0].SplitIndex(), 5); - EXPECT_EQ(tree[1].LeafValue(), 0.1f); + EXPECT_EQ(tree.LeafValue(1), 0.1f); EXPECT_TRUE(tree[1].IsLeaf()); } diff --git a/tests/cpp/tree/test_tree_stat.cc b/tests/cpp/tree/test_tree_stat.cc index eb8a7c5d910c..eba9d7ebb8f5 100644 --- a/tests/cpp/tree/test_tree_stat.cc +++ b/tests/cpp/tree/test_tree_stat.cc @@ -22,8 +22,9 @@ class UpdaterTreeStatTest : public ::testing::Test { void RunTest(std::string updater) { auto tparam = CreateEmptyGenericParam(0); + LearnerModelParam mparam; auto up = std::unique_ptr{ - TreeUpdater::Create(updater, &tparam)}; + TreeUpdater::Create(updater, &tparam, &mparam)}; up->Configure(Args{}); RegTree tree; tree.param.num_feature = kCols; From c88d5c6d8ffd9d759d9563cf4f605cfa1e59d78b Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 3 Jul 2020 04:09:26 +0800 Subject: [PATCH 2/5] Remove dependency on model parameter. --- src/tree/updater_exact.cc | 31 +++++++++++++++++-------------- src/tree/updater_exact.h | 13 +++++-------- tests/cpp/tree/test_exact.cc | 33 ++++++++++----------------------- 3 files changed, 32 insertions(+), 45 deletions(-) diff --git a/src/tree/updater_exact.cc b/src/tree/updater_exact.cc index 2a0dd8c21dd9..63ddf603e37c 100644 --- a/src/tree/updater_exact.cc +++ b/src/tree/updater_exact.cc @@ -50,16 +50,17 @@ namespace tree { template void MultiExact::InitData(DMatrix *data, - common::Span gpairs) { + common::Span gpairs, size_t targets) { monitor_.Start(__func__); + this->targets_ = targets; this->positions_.clear(); this->is_splitable_.clear(); this->nodes_split_.clear(); this->node_shift_ = 0; - CHECK_EQ(gpairs.size(), data->Info().num_row_ * mparam_->num_targets); - gpairs_ = std::vector(gpairs.size() / mparam_->num_targets, - MakeGradientPair(mparam_->num_targets)); + CHECK_EQ(gpairs.size(), data->Info().num_row_ * this->targets_); + gpairs_ = std::vector(gpairs.size() / this->targets_, + MakeGradientPair(this->targets_)); CHECK_EQ(gpairs_.size(), data->Info().num_row_); is_splitable_.resize(param_.MaxNodes(), 1); @@ -72,8 +73,8 @@ void MultiExact::InitData(DMatrix *data, // Get a vectorized veiw of gradients. 
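+  // The flat gpairs span is laid out as [row][target]; each row's targets_ consecutive scalar pairs are packed into one vector-valued gradient pair.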
for (size_t i = 0; i < data->Info().num_row_; ++i) { - size_t beg = i * mparam_->num_targets; - size_t end = beg + mparam_->num_targets; + size_t beg = i * this->targets_; + size_t end = beg + this->targets_; auto &vec = gpairs_[i]; for (size_t j = beg; j < end; ++j) { vec.GetGrad()[j - beg] = gpairs[j].GetGrad(); @@ -82,7 +83,7 @@ void MultiExact::InitData(DMatrix *data, } if (subsample != 1.0) { - size_t targets = mparam_->num_targets; + size_t targets = this->targets_; std::bernoulli_distribution flip(subsample); auto &rnd = common::GlobalRandom(); std::transform(gpairs_.begin(), gpairs_.end(), gpairs_.begin(), @@ -97,14 +98,14 @@ void MultiExact::InitData(DMatrix *data, sampler_.Init(data->Info().num_col_, param_.colsample_bynode, param_.colsample_bylevel, param_.colsample_bytree); - value_constraints_.Init(param_, &monotone_constriants_); + value_constraints_.Init(param_, targets_, &monotone_constriants_); monitor_.Stop(__func__); } template void MultiExact::InitRoot(DMatrix *data, RegTree *tree) { monitor_.Start(__func__); - GradientT root_sum {MakeGradientPair(mparam_->num_targets)}; + GradientT root_sum {MakeGradientPair(tree->LeafSize())}; root_sum = XGBOOST_PARALLEL_ACCUMULATE(gpairs_.cbegin(), gpairs_.cend(), root_sum, std::plus{}); @@ -155,7 +156,7 @@ void MultiExact::EvaluateFeature(bst_feature_t fid, bool const forward) { auto const inc = forward ? 1 : -1; auto& node_scans = *p_scans; - size_t targets { this->mparam_->num_targets }; + size_t targets { this->targets_ }; for (size_t i = node_shift_; i < nodes_split_.size(); ++i) { auto& scan = node_scans[i]; scan.candidate.left_sum = MakeGradientPair(targets); @@ -387,7 +388,9 @@ void MultiExact::ApplySplit(DMatrix *m, RegTree *p_tree) { template void MultiExact::UpdateTree(HostDeviceVector *gpair, DMatrix *data, RegTree *tree) { - this->InitData(data, gpair->ConstHostSpan()); + this->InitData(data, gpair->ConstHostSpan(), tree->LeafSize()); + CHECK_NE(this->targets_, 0); + this->InitRoot(data, tree); this->ApplySplit(data, tree); @@ -411,8 +414,8 @@ class MultiExactUpdater : public TreeUpdater { MultiTargetExact multi_; public: - explicit MultiExactUpdater(GenericParameter const *tparam, LearnerModelParam const* mparam) - : single_{tparam, mparam}, multi_{tparam, mparam} {} + explicit MultiExactUpdater(GenericParameter const *tparam) + : single_{tparam}, multi_{tparam} {} char const *Name() const override { return single_.Name(); }; void Configure(const Args &args) override { single_.Configure(args); @@ -441,7 +444,7 @@ class MultiExactUpdater : public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(MultiExact, "grow_colmaker") .describe("Grow tree with parallelization over columns.") .set_body([](GenericParameter const *tparam, LearnerModelParam const* mparam) { - return new MultiExactUpdater(tparam, mparam); + return new MultiExactUpdater(tparam); }); } // namespace tree diff --git a/src/tree/updater_exact.h b/src/tree/updater_exact.h index f90e94546590..9b91c7425045 100644 --- a/src/tree/updater_exact.h +++ b/src/tree/updater_exact.h @@ -307,8 +307,8 @@ class MultiValueConstraint { HostDeviceVector upper_storage; }; - explicit MultiValueConstraint(size_t targets) : targets_{targets} {} - void Init(TrainParam const& p, Storage* storage) { + void Init(TrainParam const& p, size_t targets, Storage* storage) { + targets_ = targets; monotone_ = {p.monotone_constraints}; if (!monotone_.empty()) { storage->lower_storage.Resize(p.MaxNodes(), @@ -437,7 +437,7 @@ class MultiExact : public TreeUpdater { // Scan of gradient statistic, used 
in enumeration. std::vector> tloc_scans_; common::ColumnSampler sampler_; - LearnerModelParam const* mparam_; + size_t targets_{0}; bool NeedForward(SparsePage::Inst const &column, TrainParam const& p) const { return p.default_direction == 2 || @@ -450,10 +450,7 @@ class MultiExact : public TreeUpdater { } public: - explicit MultiExact(GenericParameter const *runtime, - LearnerModelParam const *mparam) - : mparam_{mparam}, value_constraints_{mparam->num_targets} { - CHECK_NE(mparam->num_targets, 0); + explicit MultiExact(GenericParameter const *runtime) { if (runtime) { tparam_ = runtime; } @@ -475,7 +472,7 @@ class MultiExact : public TreeUpdater { out["train_param"] = ToJson(param_); } - void InitData(DMatrix* data, common::Span gpairs); + void InitData(DMatrix* data, common::Span gpairs, size_t targets); void InitRoot(DMatrix* data, RegTree* tree); void EvaluateFeature(bst_feature_t fid, SparsePage::Inst const &column, diff --git a/tests/cpp/tree/test_exact.cc b/tests/cpp/tree/test_exact.cc index cee1ce1befe0..61f0eed45dcd 100644 --- a/tests/cpp/tree/test_exact.cc +++ b/tests/cpp/tree/test_exact.cc @@ -45,9 +45,8 @@ class MultiExactTest : public :: testing::Test { class MultiExactUpdaterForTest : public MultiExact { public: - explicit MultiExactUpdaterForTest(GenericParameter const *runtime, - LearnerModelParam const& mparam) - : MultiExact{runtime, &mparam} { + explicit MultiExactUpdaterForTest(GenericParameter const *runtime) + : MultiExact{runtime} { this->Configure(Args{}); } decltype(gpairs_) &GetGpairs() { return gpairs_; } @@ -60,12 +59,9 @@ TEST_F(MultiExactTest, InitData) { GenericParameter runtime; runtime.InitAllowUnknown(Args{}); runtime.gpu_id = GenericParameter::kCpuId; - LearnerModelParam mparam; - mparam.output_type = OutputType::kMulti; - mparam.num_targets = kLabels; - MultiExactUpdaterForTest updater(&runtime, mparam); + MultiExactUpdaterForTest updater(&runtime); auto h_grad = common::Span{gradients_.HostVector()}; - updater.InitData(p_dmat_.get(), h_grad); + updater.InitData(p_dmat_.get(), h_grad, kLabels); auto const& gpairs = updater.GetGpairs(); @@ -91,13 +87,10 @@ TEST_F(MultiExactTest, InitRoot) { GenericParameter runtime; runtime.InitAllowUnknown(Args{}); runtime.gpu_id = GenericParameter::kCpuId; - LearnerModelParam mparam; - mparam.output_type = OutputType::kMulti; - mparam.num_targets = kLabels; - MultiExactUpdaterForTest updater{&runtime, mparam}; + MultiExactUpdaterForTest updater{&runtime}; updater.Configure(Args{}); auto h_grad = common::Span{gradients_.HostVector()}; - updater.InitData(p_dmat_.get(), h_grad); + updater.InitData(p_dmat_.get(), h_grad, kLabels); updater.InitRoot(p_dmat_.get(), &tree); auto root_weight = tree.VectorLeafValue(RegTree::kRoot); @@ -111,10 +104,7 @@ TEST_F(MultiExactTest, EvaluateSplit) { GenericParameter runtime; runtime.InitAllowUnknown(Args{}); runtime.gpu_id = GenericParameter::kCpuId; - LearnerModelParam mparam; - mparam.output_type = OutputType::kMulti; - mparam.num_targets = kLabels; - MultiExactUpdaterForTest updater{&runtime, mparam}; + MultiExactUpdaterForTest updater{&runtime}; for (auto& page : p_dmat_->GetBatches()) { auto& offset = page.offset.HostVector(); @@ -125,7 +115,7 @@ TEST_F(MultiExactTest, EvaluateSplit) { updater.Configure(Args{}); auto h_grad = common::Span{gradients_.HostVector()}; - updater.InitData(p_dmat_.get(), h_grad); + updater.InitData(p_dmat_.get(), h_grad, kLabels); updater.InitRoot(p_dmat_.get(), &tree); updater.GetNodesSplit().front().candidate.loss_chg = 0.001; @@ -141,13 +131,10 @@ 
TEST_F(MultiExactTest, ApplySplit) { GenericParameter runtime; runtime.InitAllowUnknown(Args{}); runtime.gpu_id = GenericParameter::kCpuId; - LearnerModelParam mparam; - mparam.output_type = OutputType::kMulti; - mparam.num_targets = kLabels; - MultiExactUpdaterForTest updater{&runtime, mparam}; + MultiExactUpdaterForTest updater{&runtime}; updater.Configure(Args{}); auto h_grad = common::Span{gradients_.HostVector()}; - updater.InitData(p_dmat_.get(), h_grad); + updater.InitData(p_dmat_.get(), h_grad, kLabels); updater.InitRoot(p_dmat_.get(), &tree); ASSERT_EQ(updater.GetNodesSplit().size(), 1); From 4e91685942b1945cfb72ccaab9aa4258e3ad5aed Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 3 Jul 2020 04:26:35 +0800 Subject: [PATCH 3/5] Revert changes of passing parameters. --- include/xgboost/tree_model.h | 4 +-- include/xgboost/tree_updater.h | 10 +++---- src/common/quantile.h | 3 +- src/gbm/gbtree.cc | 6 ++-- src/predictor/gpu_predictor.cu | 2 +- src/tree/tree_updater.cc | 9 ++---- src/tree/updater_colmaker.cc | 2 +- src/tree/updater_exact.cc | 6 ++-- src/tree/updater_exact.h | 5 +--- src/tree/updater_gpu_hist.cu | 4 +-- src/tree/updater_histmaker.cc | 4 +-- src/tree/updater_prune.cc | 11 ++++---- src/tree/updater_quantile_hist.cc | 10 +++---- src/tree/updater_quantile_hist.h | 3 +- src/tree/updater_refresh.cc | 2 +- src/tree/updater_skmaker.cc | 2 +- src/tree/updater_sync.cc | 2 +- tests/cpp/common/test_hist_util.h | 2 +- tests/cpp/tree/test_exact.cc | 4 +-- tests/cpp/tree/test_gpu_hist.cu | 4 +-- tests/cpp/tree/test_histmaker.cc | 8 +++--- tests/cpp/tree/test_prune.cc | 3 +- tests/cpp/tree/test_quantile_hist.cc | 41 ++++++++++++++-------------- tests/cpp/tree/test_refresh.cc | 11 ++++---- tests/cpp/tree/test_tree_stat.cc | 3 +- 25 files changed, 70 insertions(+), 91 deletions(-) diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index f95e854f8121..1e0ea46429f1 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -192,7 +192,7 @@ class RegTree : public Model { return cleft_ == kInvalidNodeId; } /*! \return get leaf value of leaf node */ - XGBOOST_DEVICE bst_float SinlgeLeafValue() const { + XGBOOST_DEVICE bst_float SingleLeafValue() const { return (this->info_).leaf_value; } /*! \return get split condition of the node */ @@ -577,7 +577,7 @@ class RegTree : public Model { } float LeafValue(bst_node_t nidx) const { CHECK_EQ(kind_, kSingle); - return (*this)[nidx].SinlgeLeafValue(); + return (*this)[nidx].SingleLeafValue(); } void SetLeaf(std::vector const& leaf, bst_node_t nid, diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h index 09b444250d38..de5c700da050 100644 --- a/include/xgboost/tree_updater.h +++ b/include/xgboost/tree_updater.h @@ -82,18 +82,16 @@ class TreeUpdater : public Configurable { * \param name Name of the tree updater. * \param tparam A global runtime parameter */ - static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam, - LearnerModelParam const* mparam); + static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam); }; /*! * \brief Registry entry for tree updater. */ struct TreeUpdaterReg - : public dmlc::FunctionRegEntryBase< - TreeUpdaterReg, - std::function> {}; + : public dmlc::FunctionRegEntryBase > { +}; /*! * \brief Macro to register tree updater. 
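+ * A registered updater is later obtained through TreeUpdater::Create(name, tparam), which finds the entry in this registry, constructs the updater and assigns its tparam_ (see tree_updater.cc).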
diff --git a/src/common/quantile.h b/src/common/quantile.h index f4f7c4cd7a7f..c0079ff8ebc8 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -564,8 +564,7 @@ class QuantileSketchTemplate { // check invariant size_t n = (1ULL << nlevel); CHECK(n * limit_size >= maxn) << "invalid init parameter"; - CHECK(nlevel <= std::max(static_cast(1), - static_cast(limit_size * eps))) + CHECK(nlevel <= std::max(static_cast(1), static_cast(limit_size * eps))) << "invalid init parameter"; } diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index cb3d0c23f5ef..6a0b10579cb4 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -253,8 +253,7 @@ void GBTree::InitUpdater(Args const& cfg) { // create new updaters for (const std::string& pstr : ups) { - std::unique_ptr up(TreeUpdater::Create(pstr.c_str(), generic_param_, - model_.learner_model_param)); + std::unique_ptr up(TreeUpdater::Create(pstr.c_str(), generic_param_)); up->Configure(cfg); updaters_.push_back(std::move(up)); } @@ -356,8 +355,7 @@ void GBTree::LoadConfig(Json const& in) { updaters_.clear(); for (auto const& kv : j_updaters) { CHECK(model_.learner_model_param); - std::unique_ptr up(TreeUpdater::Create(kv.first, generic_param_, - model_.learner_model_param)); + std::unique_ptr up(TreeUpdater::Create(kv.first, generic_param_)); up->LoadConfig(kv.second); updaters_.push_back(std::move(up)); } diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 0498ea7f06f1..451ef44e86b8 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -175,7 +175,7 @@ __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree, } } } - return n.SinlgeLeafValue(); + return n.SingleLeafValue(); } template diff --git a/src/tree/tree_updater.cc b/src/tree/tree_updater.cc index 1b2e7e87eb8c..24ff2ab92968 100644 --- a/src/tree/tree_updater.cc +++ b/src/tree/tree_updater.cc @@ -14,16 +14,13 @@ DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg); namespace xgboost { -TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const* tparam, - LearnerModelParam const* mparam) { +TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const* tparam) { auto *e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name); if (e == nullptr) { LOG(FATAL) << "Unknown tree updater " << name; } - auto p_updater = (e->body)(tparam, mparam); - if (!p_updater->tparam_) { - p_updater->tparam_ = tparam; - } + auto p_updater = (e->body)(); + p_updater->tparam_ = tparam; return p_updater; } diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index f358fc6d738c..cf137311be23 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -603,7 +603,7 @@ class ColMaker: public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker_deprecated") .describe("Grow tree with parallelization over columns.") -.set_body([](GenericParameter const* tparam, LearnerModelParam const* mparam) { +.set_body([]() { return new ColMaker(); }); } // namespace tree diff --git a/src/tree/updater_exact.cc b/src/tree/updater_exact.cc index 63ddf603e37c..1f52c3404c7a 100644 --- a/src/tree/updater_exact.cc +++ b/src/tree/updater_exact.cc @@ -414,8 +414,6 @@ class MultiExactUpdater : public TreeUpdater { MultiTargetExact multi_; public: - explicit MultiExactUpdater(GenericParameter const *tparam) - : single_{tparam}, multi_{tparam} {} char const *Name() const override { return single_.Name(); }; void Configure(const Args &args) override { 
single_.Configure(args); @@ -443,8 +441,8 @@ class MultiExactUpdater : public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(MultiExact, "grow_colmaker") .describe("Grow tree with parallelization over columns.") - .set_body([](GenericParameter const *tparam, LearnerModelParam const* mparam) { - return new MultiExactUpdater(tparam); + .set_body([]() { + return new MultiExactUpdater(); }); } // namespace tree diff --git a/src/tree/updater_exact.h b/src/tree/updater_exact.h index 9b91c7425045..9a64ac909b88 100644 --- a/src/tree/updater_exact.h +++ b/src/tree/updater_exact.h @@ -450,10 +450,7 @@ class MultiExact : public TreeUpdater { } public: - explicit MultiExact(GenericParameter const *runtime) { - if (runtime) { - tparam_ = runtime; - } + explicit MultiExact() { node_shift_ = 0; monitor_.Init(__func__); } diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 3ab436f6f264..25d2645e1032 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -890,9 +890,7 @@ class GPUHistMaker : public TreeUpdater { #if !defined(GTEST_TEST) XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist") .describe("Grow tree with GPU.") - .set_body([](GenericParameter const* tparam, LearnerModelParam const* mparam) { - return new GPUHistMaker(); - }); + .set_body([]() { return new GPUHistMaker(); }); #endif // !defined(GTEST_TEST) } // namespace tree diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index 1dbac5f7a211..c4fdbe3c0308 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -743,14 +743,14 @@ class GlobalProposalHistMaker: public CQHistMaker { XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker") .describe("Tree constructor that uses approximate histogram construction.") -.set_body([](GenericParameter const* tparam, LearnerModelParam const* mparam) { +.set_body([]() { return new CQHistMaker(); }); // The updater for approx tree method. 
XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker") .describe("Tree constructor that uses approximate global of histogram construction.") -.set_body([](GenericParameter const* tparam, LearnerModelParam const* mparam) { +.set_body([]() { return new GlobalProposalHistMaker(); }); } // namespace tree diff --git a/src/tree/updater_prune.cc b/src/tree/updater_prune.cc index 2c0a1d07ec18..76a8916a0598 100644 --- a/src/tree/updater_prune.cc +++ b/src/tree/updater_prune.cc @@ -24,7 +24,7 @@ DMLC_REGISTRY_FILE_TAG(updater_prune); class TreePruner: public TreeUpdater { public: TreePruner() { - syncher_.reset(TreeUpdater::Create("sync", tparam_, mparam_)); + syncher_.reset(TreeUpdater::Create("sync", tparam_)); pruner_monitor_.Init("TreePruner"); } char const* Name() const override { @@ -98,23 +98,22 @@ class TreePruner: public TreeUpdater { npruned = this->TryPruneLeaf(tree, nid, tree.GetDepth(nid), npruned); } } - // LOG(INFO) << "tree pruning end, " - // << tree.NumExtraNodes() << " extra nodes, " << npruned - // << " pruned nodes, max_depth=" << tree.MaxDepth(); + LOG(INFO) << "tree pruning end, " + << tree.NumExtraNodes() << " extra nodes, " << npruned + << " pruned nodes, max_depth=" << tree.MaxDepth(); } private: // synchronizer std::unique_ptr syncher_; // training parameter - LearnerModelParam const* mparam_; TrainParam param_; common::Monitor pruner_monitor_; }; XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune") .describe("Pruner that prune the tree according to statistics.") -.set_body([](GenericParameter const* tparam, LearnerModelParam const* mparam) { +.set_body([]() { return new TreePruner(); }); } // namespace tree diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 7055504c726e..be8ccfc12cc2 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -40,7 +40,7 @@ DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam); void QuantileHistMaker::Configure(const Args& args) { // initialize pruner if (!pruner_) { - pruner_.reset(TreeUpdater::Create("prune", tparam_, mparam_)); + pruner_.reset(TreeUpdater::Create("prune", tparam_)); } pruner_->Configure(args); param_.UpdateAllowUnknown(args); @@ -1360,17 +1360,17 @@ XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker") .describe("(Deprecated, use grow_quantile_histmaker instead.)" " Grow tree using quantized histogram.") .set_body( - [](GenericParameter const* tparam, LearnerModelParam const* mparam) { + []() { LOG(WARNING) << "grow_fast_histmaker is deprecated, " << "use grow_quantile_histmaker instead."; - return new QuantileHistMaker(mparam); + return new QuantileHistMaker(); }); XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker") .describe("Grow tree using quantized histogram.") .set_body( - [](GenericParameter const* tparam, LearnerModelParam const* mparam) { - return new QuantileHistMaker(mparam); + []() { + return new QuantileHistMaker(); }); } // namespace tree diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 29937ccba813..d74c16f72ceb 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -110,7 +110,7 @@ struct CPUHistMakerTrainParam /*! 
\brief construct a tree using quantized feature values */ class QuantileHistMaker: public TreeUpdater { public: - explicit QuantileHistMaker(LearnerModelParam const* mparam) : mparam_{mparam} { + QuantileHistMaker() { updater_monitor_.Init("QuantileHistMaker"); } void Configure(const Args& args) override; @@ -155,7 +155,6 @@ class QuantileHistMaker: public TreeUpdater { CPUHistMakerTrainParam hist_maker_param_; // training parameter TrainParam param_; - LearnerModelParam const* mparam_; // quantized data matrix GHistIndexMatrix gmat_; // (optional) data matrix with feature grouping diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index 5a42b6c1ff04..d63d88c802d9 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -153,7 +153,7 @@ class TreeRefresher: public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh") .describe("Refresher that refreshes the weight and statistics according to data.") -.set_body([](GenericParameter const* tparam, LearnerModelParam const* mparam) { +.set_body([]() { return new TreeRefresher(); }); } // namespace tree diff --git a/src/tree/updater_skmaker.cc b/src/tree/updater_skmaker.cc index 50914bc832e7..69cb4e58bbed 100644 --- a/src/tree/updater_skmaker.cc +++ b/src/tree/updater_skmaker.cc @@ -390,7 +390,7 @@ class SketchMaker: public BaseMaker { XGBOOST_REGISTER_TREE_UPDATER(SketchMaker, "grow_skmaker") .describe("Approximate sketching maker.") -.set_body([](GenericParameter const* tparam, LearnerModelParam const* mparam) { +.set_body([]() { return new SketchMaker(); }); } // namespace tree diff --git a/src/tree/updater_sync.cc b/src/tree/updater_sync.cc index 2842958c0375..578bfb83cea9 100644 --- a/src/tree/updater_sync.cc +++ b/src/tree/updater_sync.cc @@ -53,7 +53,7 @@ class TreeSyncher: public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync") .describe("Syncher that synchronize the tree in all distributed nodes.") -.set_body([](GenericParameter const* tparam, LearnerModelParam const* mparam) { +.set_body([]() { return new TreeSyncher(); }); } // namespace tree diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index e6132f619131..55edb324fee1 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -1,5 +1,4 @@ #pragma once -#include #include #include #include @@ -149,6 +148,7 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx, const std::vector& sorted_column, const std::vector& sorted_weights, size_t num_bins) { + // Check the endpoints are correct CHECK_GT(sorted_column.size(), 0); EXPECT_LT(cuts.MinValues().at(column_idx), sorted_column.front()); diff --git a/tests/cpp/tree/test_exact.cc b/tests/cpp/tree/test_exact.cc index 61f0eed45dcd..cfd4573b2c98 100644 --- a/tests/cpp/tree/test_exact.cc +++ b/tests/cpp/tree/test_exact.cc @@ -45,8 +45,8 @@ class MultiExactTest : public :: testing::Test { class MultiExactUpdaterForTest : public MultiExact { public: - explicit MultiExactUpdaterForTest(GenericParameter const *runtime) - : MultiExact{runtime} { + explicit MultiExactUpdaterForTest(GenericParameter const *runtime) { + this->tparam_ = runtime; this->Configure(Args{}); } decltype(gpairs_) &GetGpairs() { return gpairs_; } diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index aba2e9023295..fd5c9f43fb2a 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -489,9 +489,7 @@ TEST(GpuHist, ExternalMemoryWithSampling) { TEST(GpuHist, 
ConfigIO) { GenericParameter generic_param(CreateEmptyGenericParam(0)); - LearnerModelParam mparam; - std::unique_ptr updater {TreeUpdater::Create("grow_gpu_hist", - &generic_param, &mparam) }; + std::unique_ptr updater {TreeUpdater::Create("grow_gpu_hist", &generic_param) }; updater->Configure(Args{}); Json j_updater { Object() }; diff --git a/tests/cpp/tree/test_histmaker.cc b/tests/cpp/tree/test_histmaker.cc index 248f4a25f3bf..e1cb3568d5ef 100644 --- a/tests/cpp/tree/test_histmaker.cc +++ b/tests/cpp/tree/test_histmaker.cc @@ -33,8 +33,8 @@ TEST(GrowHistMaker, InteractionConstraint) { // With constraints RegTree tree; tree.param.num_feature = kCols; - LearnerModelParam mparam; - std::unique_ptr updater { TreeUpdater::Create("grow_histmaker", ¶m, &mparam) }; + + std::unique_ptr updater { TreeUpdater::Create("grow_histmaker", ¶m) }; updater->Configure(Args{ {"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}}); @@ -50,8 +50,8 @@ TEST(GrowHistMaker, InteractionConstraint) { // Without constraints RegTree tree; tree.param.num_feature = kCols; - LearnerModelParam mparam; - std::unique_ptr updater { TreeUpdater::Create("grow_histmaker", ¶m, &mparam) }; + + std::unique_ptr updater { TreeUpdater::Create("grow_histmaker", ¶m) }; updater->Configure(Args{{"num_feature", std::to_string(kCols)}}); updater->Update(&gradients, p_dmat.get(), {&tree}); diff --git a/tests/cpp/tree/test_prune.cc b/tests/cpp/tree/test_prune.cc index 20b91baa2db2..dbe910a8f183 100644 --- a/tests/cpp/tree/test_prune.cc +++ b/tests/cpp/tree/test_prune.cc @@ -38,8 +38,7 @@ TEST(Updater, Prune) { tree.param.UpdateAllowUnknown(cfg); std::vector trees {&tree}; // prepare pruner - LearnerModelParam mparam; - std::unique_ptr pruner(TreeUpdater::Create("prune", &lparam, &mparam)); + std::unique_ptr pruner(TreeUpdater::Create("prune", &lparam)); pruner->Configure(cfg); // loss_chg < min_split_loss; diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index 140e3950b928..1b6ab89e9992 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -466,46 +466,47 @@ class QuantileHistMock : public QuantileHistMaker { int static constexpr kNRows = 8, kNCols = 16; std::shared_ptr dmat_; - LearnerModelParam mparam_; const std::vector > cfg_; std::shared_ptr > float_builder_; std::shared_ptr > double_builder_; public: explicit QuantileHistMock( - const std::vector> &args, - const bool single_precision_histogram = false, bool batch = true) - : QuantileHistMaker{&mparam_}, cfg_{args} { + const std::vector >& args, + const bool single_precision_histogram = false, bool batch = true) : + cfg_{args} { QuantileHistMaker::Configure(args); spliteval_->Init(¶m_); dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix(); if (single_precision_histogram) { - float_builder_.reset(new BuilderMock( - param_, std::move(pruner_), - std::unique_ptr(spliteval_->GetHostClone()), - int_constraint_, dmat_.get())); + float_builder_.reset( + new BuilderMock( + param_, + std::move(pruner_), + std::unique_ptr(spliteval_->GetHostClone()), + int_constraint_, + dmat_.get())); if (batch) { float_builder_->SetHistSynchronizer(new BatchHistSynchronizer()); float_builder_->SetHistRowsAdder(new BatchHistRowsAdder()); } else { - float_builder_->SetHistSynchronizer( - new DistributedHistSynchronizer()); + float_builder_->SetHistSynchronizer(new DistributedHistSynchronizer()); float_builder_->SetHistRowsAdder(new DistributedHistRowsAdder()); } } else { - 
double_builder_.reset(new BuilderMock( - param_, std::move(pruner_), - std::unique_ptr(spliteval_->GetHostClone()), - int_constraint_, dmat_.get())); + double_builder_.reset( + new BuilderMock( + param_, + std::move(pruner_), + std::unique_ptr(spliteval_->GetHostClone()), + int_constraint_, + dmat_.get())); if (batch) { - double_builder_->SetHistSynchronizer( - new BatchHistSynchronizer()); + double_builder_->SetHistSynchronizer(new BatchHistSynchronizer()); double_builder_->SetHistRowsAdder(new BatchHistRowsAdder()); } else { - double_builder_->SetHistSynchronizer( - new DistributedHistSynchronizer()); - double_builder_->SetHistRowsAdder( - new DistributedHistRowsAdder()); + double_builder_->SetHistSynchronizer(new DistributedHistSynchronizer()); + double_builder_->SetHistRowsAdder(new DistributedHistRowsAdder()); } } } diff --git a/tests/cpp/tree/test_refresh.cc b/tests/cpp/tree/test_refresh.cc index 927a2bd72491..de2f39c267c9 100644 --- a/tests/cpp/tree/test_refresh.cc +++ b/tests/cpp/tree/test_refresh.cc @@ -32,23 +32,22 @@ TEST(Updater, Refresh) { auto lparam = CreateEmptyGenericParam(GPUIDX); tree.param.UpdateAllowUnknown(cfg); std::vector trees {&tree}; - LearnerModelParam mparam; - std::unique_ptr refresher(TreeUpdater::Create("refresh", &lparam, &mparam)); + std::unique_ptr refresher(TreeUpdater::Create("refresh", &lparam)); tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f, /*left_sum=*/0.0f, /*right_sum=*/0.0f); int cleft = tree[0].LeftChild(); int cright = tree[0].RightChild(); - tree.Stat(cleft).base_weight = 1.2f; - tree.Stat(cright).base_weight = 1.3f; + tree.Stat(cleft).base_weight = 1.2; + tree.Stat(cright).base_weight = 1.3; refresher->Configure(cfg); refresher->Update(&gpair, p_dmat.get(), trees); bst_float constexpr kEps = 1e-6; - ASSERT_NEAR(-0.183392f, tree.LeafValue(cright), kEps); - ASSERT_NEAR(-0.224489f, tree.Stat(0).loss_chg, kEps); + ASSERT_NEAR(-0.183392, tree[cright].SingleLeafValue(), kEps); + ASSERT_NEAR(-0.224489, tree.Stat(0).loss_chg, kEps); ASSERT_NEAR(0, tree.Stat(cleft).loss_chg, kEps); ASSERT_NEAR(0, tree.Stat(1).loss_chg, kEps); ASSERT_NEAR(0, tree.Stat(2).loss_chg, kEps); diff --git a/tests/cpp/tree/test_tree_stat.cc b/tests/cpp/tree/test_tree_stat.cc index eba9d7ebb8f5..eb8a7c5d910c 100644 --- a/tests/cpp/tree/test_tree_stat.cc +++ b/tests/cpp/tree/test_tree_stat.cc @@ -22,9 +22,8 @@ class UpdaterTreeStatTest : public ::testing::Test { void RunTest(std::string updater) { auto tparam = CreateEmptyGenericParam(0); - LearnerModelParam mparam; auto up = std::unique_ptr{ - TreeUpdater::Create(updater, &tparam, &mparam)}; + TreeUpdater::Create(updater, &tparam)}; up->Configure(Args{}); RegTree tree; tree.param.num_feature = kCols; From 08c6b7709377e6eb7ee5f7e7f0a2d983fd118061 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 3 Jul 2020 04:51:47 +0800 Subject: [PATCH 4/5] Fix validation. --- include/xgboost/data.h | 5 ++++- src/data/data.cc | 8 ++++---- src/gbm/gbtree.cc | 2 +- src/learner.cc | 2 +- tests/cpp/data/test_metainfo.cc | 6 +++--- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 903af20e44bc..ac593282d28e 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -104,8 +104,11 @@ class MetaInfo { /*! * \brief Validate all metainfo. + * + * \param device GPU ID + * \param targets Number of output targets. */ - void Validate(int32_t device) const; + void Validate(int32_t device, size_t targets) const; MetaInfo Slice(common::Span ridxs) const; /*! 
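For a concrete picture of the extended contract, a minimal sketch of how the two-argument Validate is meant to be called (it mirrors the test updates later in this patch; the 4-row, 2-target sizes are made up for illustration):

    // Labels for a multi-target DMatrix must hold num_row_ * targets entries.
    MetaInfo info;
    info.num_row_ = 4;
    std::vector<float> labels(4 * 2, 0.0f);
    info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, labels.size());
    info.Validate(/*device=*/0, /*targets=*/2);  // passes: 8 == 4 * 2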
diff --git a/src/data/data.cc b/src/data/data.cc index 784bbd30d82c..417ff8a65343 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -383,7 +383,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) { } } -void MetaInfo::Validate(int32_t device) const { +void MetaInfo::Validate(int32_t device, size_t targets) const { if (group_ptr_.size() != 0 && weights_.Size() != 0) { CHECK_EQ(group_ptr_.size(), weights_.Size() + 1) << "Size of weights must equal to number of groups when ranking " @@ -411,19 +411,19 @@ void MetaInfo::Validate(int32_t device) const { return; } if (labels_.Size() != 0) { - CHECK_EQ(labels_.Size(), num_row_) + CHECK_EQ(labels_.Size(), num_row_ * targets) << "Size of labels must equal to number of rows."; check_device(labels_); return; } if (labels_lower_bound_.Size() != 0) { - CHECK_EQ(labels_lower_bound_.Size(), num_row_) + CHECK_EQ(labels_lower_bound_.Size(), num_row_ * targets) << "Size of label_lower_bound must equal to number of rows."; check_device(labels_lower_bound_); return; } if (labels_upper_bound_.Size() != 0) { - CHECK_EQ(labels_upper_bound_.Size(), num_row_) + CHECK_EQ(labels_upper_bound_.Size(), num_row_ * targets) << "Size of label_upper_bound must equal to number of rows."; check_device(labels_upper_bound_); return; diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 6a0b10579cb4..9eaa5b166906 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -301,7 +301,7 @@ void GBTree::BoostNewTrees(HostDeviceVector* gpair, } } // update the trees - CHECK_EQ(gpair->Size(), p_fmat->Info().num_row_) + CHECK_EQ(gpair->Size(), p_fmat->Info().num_row_ * model_.learner_model_param->num_targets) << "Mismatching size between number of rows from input data and size of " "gradient vector."; for (auto& up : updaters_) { diff --git a/src/learner.cc b/src/learner.cc index 26ff551ac0e5..a0419c0ecbfc 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -1097,7 +1097,7 @@ class LearnerImpl : public LearnerIO { void ValidateDMatrix(DMatrix* p_fmat) const { MetaInfo const& info = p_fmat->Info(); - info.Validate(generic_parameters_.gpu_id); + info.Validate(generic_parameters_.gpu_id, learner_model_param_.num_targets); auto const row_based_split = [this]() { return tparam_.dsplit == DataSplitMode::kRow || diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index 469405724daf..0b965a4bdea6 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -169,18 +169,18 @@ TEST(MetaInfo, Validate) { info.num_col_ = 3; std::vector groups (11); info.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, 11); - EXPECT_THROW(info.Validate(0), dmlc::Error); + EXPECT_THROW(info.Validate(0, 1), dmlc::Error); std::vector labels(info.num_row_ + 1); info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1); - EXPECT_THROW(info.Validate(0), dmlc::Error); + EXPECT_THROW(info.Validate(0, 1), dmlc::Error); #if defined(XGBOOST_USE_CUDA) info.group_ptr_.clear(); labels.resize(info.num_row_); info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_); info.labels_.SetDevice(0); - EXPECT_THROW(info.Validate(1), dmlc::Error); + EXPECT_THROW(info.Validate(1, 1), dmlc::Error); #endif // defined(XGBOOST_USE_CUDA) } From f2424a4ccfcb6c7459efc185ca46885316259210 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 3 Jul 2020 05:44:24 +0800 Subject: [PATCH 5/5] Remove tree kind. 
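The tree-level TreeKind enum is gone: RegTree now shares the learner's OutputType, and num_output_group doubles as the number of targets. A minimal sketch of the resulting construction path (names as introduced by this series; the leaf size of 3 is arbitrary):

    RegTree scalar_tree;                         // defaults to leaf_size = 1, OutputType::kSingle
    RegTree vector_tree(3, OutputType::kMulti);  // every leaf stores 3 values
    CHECK_EQ(vector_tree.Kind(), OutputType::kMulti);  // printable via the new operator<<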
--- include/xgboost/model.h | 9 +++++++-- include/xgboost/tree_model.h | 21 ++++++++------------- src/gbm/gbtree.cc | 12 +++++------- src/gbm/gbtree.h | 7 ------- src/gbm/gbtree_model.h | 2 -- src/learner.cc | 19 ++++--------------- src/predictor/cpu_predictor.cc | 10 ++++------ src/tree/tree_model.cc | 14 +++++++------- src/tree/updater_exact.cc | 2 +- tests/cpp/tree/test_exact.cc | 6 +++--- 10 files changed, 39 insertions(+), 63 deletions(-) diff --git a/include/xgboost/model.h b/include/xgboost/model.h index bd24a7fea3dd..12344c66390c 100644 --- a/include/xgboost/model.h +++ b/include/xgboost/model.h @@ -7,6 +7,7 @@ #define XGBOOST_MODEL_H_ #include +#include namespace dmlc { class Stream; @@ -51,6 +52,11 @@ enum class OutputType : int32_t { kMulti }; +inline std::ostream& operator<<(std::ostream& os, OutputType t) { + os << static_cast(t); + return os; +} + /* * \brief Basic Model Parameters, used to describe the booster. */ @@ -61,8 +67,7 @@ struct LearnerModelParam { uint32_t num_feature { 0 }; /* \brief number of classes, if it is multi-class classification */ uint32_t num_output_group { 0 }; - /* \brief number of target variables. */ - uint32_t num_targets { 1 }; + /* \brief Output type of a tree, either single or multi. */ OutputType output_type { OutputType::kSingle }; LearnerModelParam() = default; diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 1e0ea46429f1..3f719f96b238 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -137,14 +137,8 @@ class MultiTargetTreeNodeStat { * This is the data structure used in xgboost's major tree models. */ class RegTree : public Model { - public: - enum TreeKind : int { - kSingle, - kMulti - }; - private: - TreeKind kind_ {kSingle}; + OutputType kind_ {OutputType::kSingle}; public: using SplitCondT = bst_float; @@ -289,8 +283,9 @@ class RegTree : public Model { Info info_; }; - explicit RegTree(bst_feature_t leaf_size = 1, TreeKind kind = kSingle) : - kind_{kind}, leaf_size_{leaf_size}, multi_target_stats_{leaf_size} { + explicit RegTree(bst_feature_t leaf_size = 1, + OutputType kind = OutputType::kSingle) + : kind_{kind}, leaf_size_{leaf_size}, multi_target_stats_{leaf_size} { param.num_nodes = 1; param.num_deleted = 0; nodes_.resize(param.num_nodes); @@ -302,13 +297,13 @@ class RegTree : public Model { if (leaf_size_ != 1) { leaf_values_.resize(leaf_size_); - CHECK_EQ(kind_, kMulti); + CHECK_EQ(static_cast(kind_), static_cast(OutputType::kMulti)); } } /*! * \brief Return tree kind, kSingle or kMulti. */ - TreeKind Kind() const { return kind_; } + OutputType Kind() const { return kind_; } /*! * \brief Return the size of leaf. 
*/ @@ -571,12 +566,12 @@ class RegTree : public Model { }; common::Span VectorLeafValue(bst_node_t nidx) const { - CHECK_EQ(kind_, kMulti); + CHECK_EQ(static_cast(kind_), static_cast(OutputType::kMulti)); auto s = common::Span {leaf_values_}.subspan(nidx * leaf_size_, leaf_size_); return s; } float LeafValue(bst_node_t nidx) const { - CHECK_EQ(kind_, kSingle); + CHECK_EQ(kind_, OutputType::kSingle); return (*this)[nidx].SingleLeafValue(); } diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 9eaa5b166906..dd3ed5798470 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -185,7 +185,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, PredictionCacheEntry* predt) { std::vector > > new_trees; - const int ngroup = model_.learner_model_param->num_output_group; + bst_group_t ngroup = model_.learner_model_param->num_output_group; ConfigureWithKnownData(this->cfg_, p_fmat); monitor_.Start("BoostNewTrees"); CHECK_NE(ngroup, 0); @@ -275,11 +275,8 @@ void GBTree::BoostNewTrees(HostDeviceVector* gpair, "trees."; // create new tree std::unique_ptr ptr; - if (model_.learner_model_param->output_type == OutputType::kSingle) { - ptr.reset(new RegTree(1, RegTree::kSingle)); - } else { - ptr.reset(new RegTree(model_.learner_model_param->num_targets, RegTree::kMulti)); - } + ptr.reset(new RegTree(model_.learner_model_param->num_output_group, + model_.learner_model_param->output_type)); ptr->param.UpdateAllowUnknown(this->cfg_); new_trees.push_back(ptr.get()); ret->push_back(std::move(ptr)); @@ -301,7 +298,8 @@ void GBTree::BoostNewTrees(HostDeviceVector* gpair, } } // update the trees - CHECK_EQ(gpair->Size(), p_fmat->Info().num_row_ * model_.learner_model_param->num_targets) + CHECK_EQ(gpair->Size(), p_fmat->Info().num_row_ * + model_.learner_model_param->num_output_group) << "Mismatching size between number of rows from input data and size of " "gradient vector."; for (auto& up : updaters_) { diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 1a60a4a1df49..c8ca0b5fbf45 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -71,8 +71,6 @@ struct GBTreeTrainParam : public XGBoostParameter { PredictorType predictor; // tree construction method TreeMethod tree_method; - /*! 
\brief size of leaf vector needed in tree */ - RegTree::TreeKind tree_type; // declare parameters DMLC_DECLARE_PARAMETER(GBTreeTrainParam) { DMLC_DECLARE_FIELD(num_parallel_tree) @@ -104,11 +102,6 @@ struct GBTreeTrainParam : public XGBoostParameter { .add_enum("hist", TreeMethod::kHist) .add_enum("gpu_hist", TreeMethod::kGPUHist) .describe("Choice of tree construction method."); - DMLC_DECLARE_FIELD(tree_type) - .add_enum("single", RegTree::TreeKind::kSingle) - .add_enum("multi", RegTree::TreeKind::kMulti) - .set_default(RegTree::TreeKind::kSingle) - .describe("Type of tree."); } }; diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h index a2b0d4a0a99c..bd0eb57155a1 100644 --- a/src/gbm/gbtree_model.h +++ b/src/gbm/gbtree_model.h @@ -17,8 +17,6 @@ #include "xgboost/parameter.h" #include "xgboost/model.h" -DECLARE_FIELD_ENUM_CLASS(xgboost::RegTree::TreeKind); - namespace xgboost { class Json; diff --git a/src/learner.cc b/src/learner.cc index a0419c0ecbfc..da542bedee2f 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -154,15 +154,12 @@ struct LearnerModelParamLegacy : public dmlc::Parameter LearnerModelParam::LearnerModelParam( LearnerModelParamLegacy const &user_param, float base_margin) : base_score{base_margin}, num_feature{user_param.num_feature}, - num_targets{user_param.num_targets}, output_type{user_param.output_type} + output_type{user_param.output_type} { if (user_param.output_type == OutputType::kSingle) { CHECK(user_param.num_class == 0 || user_param.num_targets == 0); num_output_group = std::max(static_cast(user_param.num_class), user_param.num_targets); - num_targets = 1; - } else { - num_targets = std::max(num_targets, static_cast(user_param.num_class)); } num_output_group = std::max(num_output_group, 1u); } @@ -1097,7 +1094,7 @@ class LearnerImpl : public LearnerIO { void ValidateDMatrix(DMatrix* p_fmat) const { MetaInfo const& info = p_fmat->Info(); - info.Validate(generic_parameters_.gpu_id, learner_model_param_.num_targets); + info.Validate(generic_parameters_.gpu_id, learner_model_param_.num_output_group); auto const row_based_split = [this]() { return tparam_.dsplit == DataSplitMode::kRow || @@ -1107,16 +1104,8 @@ class LearnerImpl : public LearnerIO { CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_) << "Number of columns does not match number of features in booster."; } - - if (learner_model_param_.output_type == OutputType::kSingle) { - CHECK(p_fmat->Info().labels_cols == 1 || - p_fmat->Info().labels_cols == learner_model_param_.num_output_group); - } else { - CHECK(p_fmat->Info().labels_cols == learner_model_param_.num_targets || - p_fmat->Info().labels_cols == learner_model_param_.num_output_group) - << "p_fmat->Info().labels_cols: " << p_fmat->Info().labels_cols << ", " - << "learner_model_param_.num_targets: " << learner_model_param_.num_targets; - } + CHECK(p_fmat->Info().labels_cols == 1 || + p_fmat->Info().labels_cols == learner_model_param_.num_output_group); } private: diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 63d15800da31..2e0a16fe0aba 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -215,7 +215,7 @@ class CPUPredictor : public Predictor { page.data.HostVector(); page.offset.HostVector(); dmlc::OMPException omp_handler; - size_t targets = model.learner_model_param->num_targets; + size_t targets = model.learner_model_param->num_output_group; #pragma omp parallel for for (omp_ulong i = 0; i < page.Size(); ++i) { omp_handler.Run( @@ -253,9 +253,7 @@ class 
CPUPredictor : public Predictor {
                              HostDeviceVector* out_preds,
                              const gbm::GBTreeModel& model) const {
     CHECK_NE(model.learner_model_param->num_output_group, 0);
-    size_t n = std::max(model.learner_model_param->num_targets,
-                        model.learner_model_param->num_output_group) *
-               info.num_row_;
+    size_t n = model.learner_model_param->num_output_group * info.num_row_;
     const auto& base_margin = info.base_margin_.HostVector();
     std::vector& out_preds_h = out_preds->HostVector();
     // size_t const out_size = info.labels_cols * n;
@@ -312,7 +310,7 @@ class CPUPredictor : public Predictor {
     CHECK_LE(beg_version, end_version);
     if (beg_version < end_version) {
-      if (model.trees.front()->Kind() == RegTree::kMulti) {
+      if (model.trees.front()->Kind() == OutputType::kMulti) {
         CHECK_EQ(output_groups, 1);
         for (auto const& page : dmat->GetBatches()) {
           this->PredictVectorInternal(page, model, &out_preds->HostVector(),
@@ -333,7 +331,7 @@ class CPUPredictor : public Predictor {
     CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ ||
           out_preds->Size() ==
-              model.learner_model_param->num_targets * dmat->Info().num_row_);
+              model.learner_model_param->num_output_group * dmat->Info().num_row_);
   }
 
   template
diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc
index 99e68f41d88a..d94bf50d5ea7 100644
--- a/src/tree/tree_model.cc
+++ b/src/tree/tree_model.cc
@@ -613,7 +613,7 @@ constexpr bst_node_t RegTree::kRoot;
 std::string RegTree::DumpModel(const FeatureMap& fmap,
                                bool with_stats,
                                std::string format) const {
-  CHECK_EQ(Kind(), kSingle)
+  CHECK_EQ(Kind(), OutputType::kSingle)
       << "Dump model is not available for multi-target tree.";
   std::unique_ptr builder {
     TreeGenerator::Create(format, fmap, with_stats)
@@ -625,7 +625,7 @@ std::string RegTree::DumpModel(const FeatureMap& fmap,
 }
 
 bool RegTree::Equal(const RegTree& b) const {
-  CHECK_EQ(Kind(), kSingle);
+  CHECK_EQ(Kind(), OutputType::kSingle);
   if (NumExtraNodes() != b.NumExtraNodes()) {
     return false;
   }
@@ -666,7 +666,7 @@ bst_node_t RegTree::GetNumSplitNodes() const {
 }
 
 void RegTree::Load(dmlc::Stream* fi) {
-  CHECK_NE(Kind(), kMulti) << "Multi-target tree requires JSON serialization format.";
+  CHECK_NE(Kind(), OutputType::kMulti) << "Multi-target tree requires JSON serialization format.";
   CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam));
   nodes_.resize(param.num_nodes);
   stats_.resize(param.num_nodes);
@@ -685,7 +685,7 @@ void RegTree::Load(dmlc::Stream* fi) {
   CHECK_EQ(static_cast(deleted_nodes_.size()), param.num_deleted);
 }
 
 void RegTree::Save(dmlc::Stream* fo) const {
-  CHECK_NE(Kind(), kMulti) << "Model persistent for multi-target tree is not yet implemented.";
+  CHECK_NE(Kind(), OutputType::kMulti) << "Model persistence for multi-target tree is not yet implemented.";
   CHECK_EQ(param.num_nodes, static_cast(nodes_.size()));
   CHECK_EQ(param.num_nodes, static_cast(stats_.size()));
   fo->Write(&param, sizeof(TreeParam));
@@ -761,7 +761,7 @@ void RegTree::LoadModel(Json const& in) {
 }
 
 void RegTree::SaveModel(Json* p_out) const {
-  CHECK_NE(Kind(), kMulti) << "Model persistent for multi-target tree is not yet implemented.";
+  CHECK_NE(Kind(), OutputType::kMulti) << "Model persistence for multi-target tree is not yet implemented.";
   auto& out = *p_out;
   CHECK_EQ(param.num_nodes, static_cast(nodes_.size()));
   CHECK_EQ(param.num_nodes, static_cast(stats_.size()));
@@ -838,7 +838,7 @@ bst_float RegTree::FillNodeMeanValue(int nid) {
 void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
                                            bst_float *out_contribs) const {
-  CHECK_EQ(Kind(), kSingle) << "Contribution is not available for mutli-target tree.";
+  CHECK_EQ(Kind(), OutputType::kSingle) << "Contribution is not available for multi-target tree.";
   CHECK_GT(this->node_mean_values_.size(), 0U);
   // this follows the idea of http://blog.datadive.net/interpreting-random-forests/
   unsigned split_index = 0;
@@ -954,7 +954,7 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
                        bst_float parent_one_fraction, int parent_feature_index,
                        int condition, unsigned condition_feature,
                        bst_float condition_fraction) const {
-  CHECK_EQ(Kind(), kSingle) << "Tree shap is not available for mutli-target tree.";
+  CHECK_EQ(Kind(), OutputType::kSingle) << "Tree shap is not available for multi-target tree.";
   const auto node = (*this)[node_index];
 
   // stop if we have no weight coming down to us
diff --git a/src/tree/updater_exact.cc b/src/tree/updater_exact.cc
index 1f52c3404c7a..d0ae49a9225a 100644
--- a/src/tree/updater_exact.cc
+++ b/src/tree/updater_exact.cc
@@ -431,7 +431,7 @@ class MultiExactUpdater : public TreeUpdater {
               DMatrix* data,
               const std::vector& trees) override {
     CHECK_NE(trees.size(), 0);
-    if (trees.front()->Kind() == RegTree::kSingle) {
+    if (trees.front()->Kind() == OutputType::kSingle) {
       single_.Update(gpair, data, trees);
     } else {
       multi_.Update(gpair, data, trees);
diff --git a/tests/cpp/tree/test_exact.cc b/tests/cpp/tree/test_exact.cc
index cfd4573b2c98..2fe12490a379 100644
--- a/tests/cpp/tree/test_exact.cc
+++ b/tests/cpp/tree/test_exact.cc
@@ -83,7 +83,7 @@ TEST_F(MultiExactTest, InitData) {
 }
 
 TEST_F(MultiExactTest, InitRoot) {
-  RegTree tree(p_dmat_->Info().num_col_, RegTree::kMulti);
+  RegTree tree(p_dmat_->Info().num_col_, OutputType::kMulti);
   GenericParameter runtime;
   runtime.InitAllowUnknown(Args{});
   runtime.gpu_id = GenericParameter::kCpuId;
@@ -100,7 +100,7 @@ TEST_F(MultiExactTest, InitRoot) {
 }
 
 TEST_F(MultiExactTest, EvaluateSplit) {
-  RegTree tree(p_dmat_->Info().num_col_, RegTree::kMulti);
+  RegTree tree(p_dmat_->Info().num_col_, OutputType::kMulti);
   GenericParameter runtime;
   runtime.InitAllowUnknown(Args{});
   runtime.gpu_id = GenericParameter::kCpuId;
@@ -127,7 +127,7 @@ TEST_F(MultiExactTest, EvaluateSplit) {
 }
 
 TEST_F(MultiExactTest, ApplySplit) {
-  RegTree tree(p_dmat_->Info().num_col_, RegTree::kMulti);
+  RegTree tree(p_dmat_->Info().num_col_, OutputType::kMulti);
   GenericParameter runtime;
   runtime.InitAllowUnknown(Args{});
   runtime.gpu_id = GenericParameter::kCpuId;