diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index f9c0a0ffc67b..64685ab9f1d6 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -531,6 +531,7 @@ XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle); * notice. * * \param handle handle to Booster object. + * \param out_len length of output string * \param out_str A valid pointer to array of characters. The characters array is * allocated and managed by XGBoost, while pointer to that array needs to * be managed by caller. diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h index 0f61e1c9fe90..666e65652a5b 100644 --- a/include/xgboost/gbm.h +++ b/include/xgboost/gbm.h @@ -71,7 +71,7 @@ class GradientBooster : public Model, public Configurable { * \brief perform update to the model(boosting) * \param p_fmat feature matrix that provide access to features * \param in_gpair address of the gradient pair statistics of the data - * \param obj The objective function, optional, can be nullptr when use customized version + * \param prediction The output prediction cache entry that needs to be updated. * the booster may change content of gpair */ virtual void DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, @@ -158,7 +158,6 @@ class GradientBooster : public Model, public Configurable { * \param name name of gradient booster * \param generic_param Pointer to runtime parameters * \param learner_model_param pointer to global model parameters - * \param cache_mats The cache data matrix of the Booster. * \return The created booster. */ static GradientBooster* Create( diff --git a/include/xgboost/generic_parameters.h b/include/xgboost/generic_parameters.h index cd5a49ffa754..925af2d3cc43 100644 --- a/include/xgboost/generic_parameters.h +++ b/include/xgboost/generic_parameters.h @@ -1,6 +1,6 @@ /*! * Copyright 2014-2019 by Contributors - * \file learner.cc + * \file generic_parameters.h */ #ifndef XGBOOST_GENERIC_PARAMETERS_H_ #define XGBOOST_GENERIC_PARAMETERS_H_ diff --git a/include/xgboost/json.h b/include/xgboost/json.h index cdd86d135a90..13fde3b5815b 100644 --- a/include/xgboost/json.h +++ b/include/xgboost/json.h @@ -406,7 +406,7 @@ class Json { /*! \brief Index Json object with int, used for Json Array. */ Json& operator[](int ind) const { return (*ptr_)[ind]; } - /*! \Brief Return the reference to stored Json value. */ + /*! \brief Return the reference to stored Json value. */ Value const& GetValue() const & { return *ptr_; } Value const& GetValue() && { return *ptr_; } Value& GetValue() & { return *ptr_; } diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index dbfbc1e34643..661fec0d0cf2 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -87,6 +87,7 @@ class Learner : public Model, public Configurable, public rabit::Serializable { * \param out_preds output vector that stores the prediction * \param ntree_limit limit number of trees used for boosted tree * predictor, when it equals 0, this means we are using all the trees + * \param training Whether the prediction result is used for training * \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor * \param pred_contribs whether to only predict the feature contributions * \param approx_contribs whether to approximate the feature contributions for speed diff --git a/include/xgboost/metric.h b/include/xgboost/metric.h index 36b148b7fe93..db32f719cbab 100644 --- a/include/xgboost/metric.h +++ b/include/xgboost/metric.h @@ -52,8 +52,9 @@ class Metric { /*! * \brief create a metric according to name. * \param name name of the metric. - * name can be in form metric[@]param - * and the name will be matched in the registry. + * name can be in form metric[@]param and the name will be matched in the + * registry. + * \param tparam A global generic parameter * \return the created metric. */ static Metric* Create(const std::string& name, GenericParameter const* tparam); diff --git a/include/xgboost/parameter.h b/include/xgboost/parameter.h index c3314688aec5..3186651eda16 100644 --- a/include/xgboost/parameter.h +++ b/include/xgboost/parameter.h @@ -1,6 +1,6 @@ /*! * Copyright 2018 by Contributors - * \file enum_class_param.h + * \file parameter.h * \brief macro for using C++11 enum class as DMLC parameter * \author Hyunsu Philip Cho */ diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 1b4991833a2d..a872f3437b35 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -134,31 +134,6 @@ class Predictor { uint32_t const ntree_limit = 0) = 0; /** - * \fn virtual void Predictor::UpdatePredictionCache( const gbm::GBTreeModel - * &model, std::vector >* updaters, int - * num_new_trees) = 0; - * - * \brief Update the internal prediction cache using newly added trees. Will - * use the tree updater to do this if possible. Should be called as a part of - * the tree boosting process to facilitate the look up of predictions - * at a later time. - * - * \param model The model. - * \param [in,out] updaters The updater sequence for gradient boosting. - * \param num_new_trees Number of new trees. - */ - - virtual void UpdatePredictionCache( - const gbm::GBTreeModel& model, - std::vector>* updaters, - int num_new_trees, - DMatrix* m, - PredictionCacheEntry* predts) = 0; - - /** - * \fn virtual void Predictor::PredictInstance( const SparsePage::Inst& - * inst, std::vector* out_preds, const gbm::GBTreeModel& model, - * * \brief online prediction function, predict score for one instance at a time * NOTE: use the batch prediction interface if possible, batch prediction is * usually more efficient than online prediction This function is NOT @@ -234,7 +209,6 @@ class Predictor { * * \param name Name of the predictor. * \param generic_param Pointer to runtime parameters. - * \param cache Pointer to prediction cache. */ static Predictor* Create( std::string const& name, GenericParameter const* generic_param); diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h index e24bd39bd91d..2508f130dec2 100644 --- a/include/xgboost/tree_updater.h +++ b/include/xgboost/tree_updater.h @@ -73,6 +73,7 @@ class TreeUpdater : public Configurable { /*! * \brief Create a tree updater given name * \param name Name of the tree updater. + * \param tparam A global runtime parameter */ static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam); }; diff --git a/src/common/hist_util.h b/src/common/hist_util.h index a47eae4ae807..f0442278d7a8 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -199,7 +199,7 @@ class CutsBuilder { } void AddCutPoint(WQSketch::SummaryContainer const& summary, int max_bin) { - int required_cuts = std::min(static_cast(summary.size), max_bin); + size_t required_cuts = std::min(summary.size, static_cast(max_bin)); for (size_t i = 1; i < required_cuts; ++i) { bst_float cpt = summary.data[i].value; if (i == 1 || cpt > p_cuts_->cut_values_.back()) { diff --git a/src/common/transform.h b/src/common/transform.h index bcbdba22edf8..ad4f7b9fc832 100644 --- a/src/common/transform.h +++ b/src/common/transform.h @@ -181,8 +181,7 @@ class Transform { * \param func A callable object, accepting a size_t thread index, * followed by a set of Span classes. * \param range Range object specifying parallel threads index range. - * \param devices GPUSet specifying GPUs to use, when compiling for CPU, - * this should be GPUSet::Empty(). + * \param device Specify GPU to use. * \param shard Whether Shard for HostDeviceVector is needed. */ template diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index c318d860b446..f9bcc71b0b03 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -296,8 +296,15 @@ void GBTree::CommitModel(std::vector>>&& ne num_new_trees += new_trees[gid].size(); model_.CommitModel(std::move(new_trees[gid]), gid); } - CHECK(configured_); - GetPredictor()->UpdatePredictionCache(model_, &updaters_, num_new_trees, m, predts); + auto* out = &predts->predictions; + if (model_.learner_model_param_->num_output_group == 1 && + updaters_.size() > 0 && + num_new_trees == 1 && + out->Size() > 0 && + updaters_.back()->UpdatePredictionCache(m, out)) { + auto delta = num_new_trees / model_.learner_model_param_->num_output_group; + predts->Update(delta); + } monitor_.Stop("CommitModel"); } @@ -357,6 +364,76 @@ void GBTree::SaveModel(Json* p_out) const { model_.SaveModel(&model); } +void GBTree::PredictBatch(DMatrix* p_fmat, + PredictionCacheEntry* out_preds, + bool training, + unsigned ntree_limit) { + CHECK(configured_); + GetPredictor(&out_preds->predictions, p_fmat) + ->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit); +} + +std::unique_ptr const & +GBTree::GetPredictor(HostDeviceVector const *out_pred, + DMatrix *f_dmat) const { + CHECK(configured_); + if (tparam_.predictor != PredictorType::kAuto) { + if (tparam_.predictor == PredictorType::kGPUPredictor) { +#if defined(XGBOOST_USE_CUDA) + CHECK(gpu_predictor_); + return gpu_predictor_; +#else + this->AssertGPUSupport(); +#endif // defined(XGBOOST_USE_CUDA) + } + CHECK(cpu_predictor_); + return cpu_predictor_; + } + + auto on_device = + f_dmat && + (*(f_dmat->GetBatches().begin())).data.DeviceCanRead(); + + // Use GPU Predictor if data is already on device. + if (on_device) { +#if defined(XGBOOST_USE_CUDA) + CHECK(gpu_predictor_); + return gpu_predictor_; +#else + LOG(FATAL) << "Data is on CUDA device, but XGBoost is not compiled with " + "CUDA support."; + return cpu_predictor_; +#endif // defined(XGBOOST_USE_CUDA) + } + + // GPU_Hist by default has prediction cache calculated from quantile values, + // so GPU Predictor is not used for training dataset. But when XGBoost + // performs continue training with an existing model, the prediction cache is + // not availbale and number of trees doesn't equal zero, the whole training + // dataset got copied into GPU for precise prediction. This condition tries + // to avoid such copy by calling CPU Predictor instead. + if ((out_pred && out_pred->Size() == 0) && (model_.param.num_trees != 0) && + // FIXME(trivialfis): Implement a better method for testing whether data + // is on device after DMatrix refactoring is done. + !on_device) { + CHECK(cpu_predictor_); + return cpu_predictor_; + } + + if (tparam_.tree_method == TreeMethod::kGPUHist) { +#if defined(XGBOOST_USE_CUDA) + CHECK(gpu_predictor_); + return gpu_predictor_; +#else + this->AssertGPUSupport(); + return cpu_predictor_; +#endif // defined(XGBOOST_USE_CUDA) + } + + CHECK(cpu_predictor_); + return cpu_predictor_; +} + class Dart : public GBTree { public: explicit Dart(LearnerModelParam const* booster_config) : diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index f05d084d2856..3bee789c62dc 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -201,11 +201,7 @@ class GBTree : public GradientBooster { void PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool training, - unsigned ntree_limit) override { - CHECK(configured_); - GetPredictor(&out_preds->predictions, p_fmat)->PredictBatch( - p_fmat, out_preds, model_, 0, ntree_limit); - } + unsigned ntree_limit) override; void PredictInstance(const SparsePage::Inst& inst, std::vector* out_preds, @@ -256,62 +252,7 @@ class GBTree : public GradientBooster { std::vector >* ret); std::unique_ptr const& GetPredictor(HostDeviceVector const* out_pred = nullptr, - DMatrix* f_dmat = nullptr) const { - CHECK(configured_); - if (tparam_.predictor != PredictorType::kAuto) { - if (tparam_.predictor == PredictorType::kGPUPredictor) { -#if defined(XGBOOST_USE_CUDA) - CHECK(gpu_predictor_); - return gpu_predictor_; -#else - this->AssertGPUSupport(); -#endif // defined(XGBOOST_USE_CUDA) - } - CHECK(cpu_predictor_); - return cpu_predictor_; - } - - auto on_device = f_dmat && (*(f_dmat->GetBatches().begin())).data.DeviceCanRead(); - - // Use GPU Predictor if data is already on device. - if (on_device) { -#if defined(XGBOOST_USE_CUDA) - CHECK(gpu_predictor_); - return gpu_predictor_; -#else - LOG(FATAL) << "Data is on CUDA device, but XGBoost is not compiled with CUDA support."; - return cpu_predictor_; -#endif // defined(XGBOOST_USE_CUDA) - } - - // GPU_Hist by default has prediction cache calculated from quantile values, so GPU - // Predictor is not used for training dataset. But when XGBoost performs continue - // training with an existing model, the prediction cache is not availbale and number - // of trees doesn't equal zero, the whole training dataset got copied into GPU for - // precise prediction. This condition tries to avoid such copy by calling CPU - // Predictor instead. - if ((out_pred && out_pred->Size() == 0) && - (model_.param.num_trees != 0) && - // FIXME(trivialfis): Implement a better method for testing whether data is on - // device after DMatrix refactoring is done. - !on_device) { - CHECK(cpu_predictor_); - return cpu_predictor_; - } - - if (tparam_.tree_method == TreeMethod::kGPUHist) { -#if defined(XGBOOST_USE_CUDA) - CHECK(gpu_predictor_); - return gpu_predictor_; -#else - this->AssertGPUSupport(); - return cpu_predictor_; -#endif // defined(XGBOOST_USE_CUDA) - } - - CHECK(cpu_predictor_); - return cpu_predictor_; - } + DMatrix* f_dmat = nullptr) const; // commit new trees all at once virtual void CommitModel(std::vector>>&& new_trees, diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 74c955612c2f..6a7666ccfecf 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -146,8 +146,12 @@ class CPUPredictor : public Predictor { CHECK_EQ(tree_begin, 0); auto* out_preds = &predts->predictions; CHECK_GE(predts->version, tree_begin); + if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) { + CHECK_EQ(predts->version, 0); + } if (predts->version == 0) { - CHECK_EQ(out_preds->Size(), 0); + // out_preds->Size() can be non-zero as it's initialized here before any tree is + // built at the 0^th iterator. this->InitOutPredictions(dmat->Info(), out_preds, model); } @@ -185,30 +189,6 @@ class CPUPredictor : public Predictor { out_preds->Size() == dmat->Info().num_row_); } - void UpdatePredictionCache( - const gbm::GBTreeModel& model, - std::vector>* updaters, - int num_new_trees, - DMatrix* m, - PredictionCacheEntry* predts) override { - int old_ntree = model.trees.size() - num_new_trees; - // update cache entry - auto* out = &predts->predictions; - if (predts->predictions.Size() == 0) { - this->InitOutPredictions(m->Info(), out, model); - this->PredInternal(m, &out->HostVector(), model, 0, model.trees.size()); - } else if (model.learner_model_param_->num_output_group == 1 && - updaters->size() > 0 && - num_new_trees == 1 && - updaters->back()->UpdatePredictionCache(m, out)) { - {} - } else { - PredInternal(m, &out->HostVector(), model, old_ntree, model.trees.size()); - } - auto delta = num_new_trees / model.learner_model_param_->num_output_group; - predts->Update(delta); - } - void PredictInstance(const SparsePage::Inst& inst, std::vector* out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit) override { diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index eac4bda96d09..75917ccdca3c 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -7,7 +7,6 @@ #include #include -#include "xgboost/parameter.h" #include "xgboost/data.h" #include "xgboost/predictor.h" #include "xgboost/tree_model.h" @@ -316,8 +315,10 @@ class GPUPredictor : public xgboost::Predictor { CHECK_EQ(tree_begin, 0); auto* out_preds = &predts->predictions; CHECK_GE(predts->version, tree_begin); + if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) { + CHECK_EQ(predts->version, 0); + } if (predts->version == 0) { - CHECK_EQ(out_preds->Size(), 0); this->InitOutPredictions(dmat->Info(), out_preds, model); } @@ -370,32 +371,6 @@ class GPUPredictor : public xgboost::Predictor { } } - void UpdatePredictionCache( - const gbm::GBTreeModel& model, - std::vector>* updaters, - int num_new_trees, - DMatrix* m, - PredictionCacheEntry* predts) override { - int device = generic_param_->gpu_id; - ConfigureDevice(device); - auto old_ntree = model.trees.size() - num_new_trees; - // update cache entry - auto* out = &predts->predictions; - if (predts->predictions.Size() == 0) { - InitOutPredictions(m->Info(), out, model); - DevicePredictInternal(m, out, model, 0, model.trees.size()); - } else if (model.learner_model_param_->num_output_group == 1 && - updaters->size() > 0 && - num_new_trees == 1 && - updaters->back()->UpdatePredictionCache(m, out)) { - {} - } else { - DevicePredictInternal(m, out, model, old_ntree, model.trees.size()); - } - auto delta = num_new_trees / model.learner_model_param_->num_output_group; - predts->Update(delta); - } - void PredictInstance(const SparsePage::Inst& inst, std::vector* out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit) override { diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index 60f920fe9d7d..416b4d82389c 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -1,9 +1,15 @@ #pragma once #include +#include +#include +#include +#include +#include +#include "../../../src/common/hist_util.h" #include "../../../src/data/simple_dmatrix.h" // Some helper functions used to test both GPU and CPU algorithms -// +// namespace xgboost { namespace common { @@ -13,8 +19,8 @@ inline std::vector GenerateRandom(int num_rows, int num_columns) { std::mt19937 rng(0); std::uniform_real_distribution dist(0.0, 1.0); std::generate(x.begin(), x.end(), [&]() { return dist(rng); }); - for (auto i = 0ull; i < num_columns; i++) { - for (auto j = 0ull; j < num_rows; j++) { + for (auto i = 0; i < num_columns; i++) { + for (auto j = 0; j < num_rows; j++) { x[j * num_columns + i] += i; } } @@ -22,14 +28,13 @@ inline std::vector GenerateRandom(int num_rows, int num_columns) { } inline std::vector GenerateRandomCategoricalSingleColumn(int n, - int num_categories) { + int num_categories) { std::vector x(n); std::mt19937 rng(0); std::uniform_int_distribution dist(0, num_categories - 1); std::generate(x.begin(), x.end(), [&]() { return dist(rng); }); // Make sure each category is present - for(auto i = 0ull; i < num_categories; i++) - { + for(auto i = 0; i < num_categories; i++) { x[i] = i; } return x; @@ -47,9 +52,9 @@ inline std::shared_ptr GetExternalMemoryDMatrixFromData( // Create the svm file in a temp dir const std::string tmp_file = tempdir.path + "/temp.libsvm"; std::ofstream fo(tmp_file.c_str()); - for (auto i = 0ull; i < num_rows; i++) { + for (auto i = 0; i < num_rows; i++) { std::stringstream row_data; - for (auto j = 0ull; j < num_columns; j++) { + for (auto j = 0; j < num_columns; j++) { row_data << 1 << " " << j << ":" << std::setprecision(15) << x[i * num_columns + j]; } @@ -62,7 +67,7 @@ inline std::shared_ptr GetExternalMemoryDMatrixFromData( // Test that elements are approximately equally distributed among bins inline void TestBinDistribution(const HistogramCuts& cuts, int column_idx, - const std::vector& column, + const std::vector& column, int num_bins) { std::map counts; for (auto& v : column) { @@ -144,11 +149,11 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx, // x is dense and row major inline void ValidateCuts(const HistogramCuts& cuts, std::vector& x, int num_rows, int num_columns, - int num_bins) { - for (auto i = 0ull; i < num_columns; i++) { + int num_bins) { + for (auto i = 0; i < num_columns; i++) { // Extract the column std::vector column(num_rows); - for (auto j = 0ull; j < num_rows; j++) { + for (auto j = 0; j < num_rows; j++) { column[j] = x[j*num_columns + i]; } ValidateColumn(cuts,i, column, num_bins);