diff --git a/CMakeLists.txt b/CMakeLists.txt index c7075928e577..12f6f06e6421 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,7 +49,7 @@ option(USE_SANITIZER "Use santizer flags" OFF) option(SANITIZER_PATH "Path to sanitizes.") set(ENABLED_SANITIZERS "address" "leak" CACHE STRING "Semicolon separated list of sanitizer names. E.g 'address;leak'. Supported sanitizers are -address, leak and thread.") +address, leak, undefined and thread.") ## Plugins option(PLUGIN_LZ4 "Build lz4 plugin" OFF) option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF) diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index f18632500207..660264e0b7b7 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -139,6 +139,8 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) { #' @param reshape whether to reshape the vector of predictions to a matrix form when there are several #' prediction outputs per case. This option has no effect when either of predleaf, predcontrib, #' or predinteraction flags is TRUE. +#' @param training whether is the prediction result used for training. For dart booster, +#' training predicting will perform dropout. #' @param ... Parameters passed to \code{predict.xgb.Booster} #' #' @details diff --git a/R-package/man/agaricus.test.Rd b/R-package/man/agaricus.test.Rd index 041ff4e6c813..b88b340966dc 100644 --- a/R-package/man/agaricus.test.Rd +++ b/R-package/man/agaricus.test.Rd @@ -4,7 +4,7 @@ \name{agaricus.test} \alias{agaricus.test} \title{Test part from Mushroom Data Set} -\format{A list containing a label vector, and a dgCMatrix object with 1611 +\format{A list containing a label vector, and a dgCMatrix object with 1611 rows and 126 variables} \usage{ data(agaricus.test) @@ -24,8 +24,8 @@ This data set includes the following fields: \references{ https://archive.ics.uci.edu/ml/datasets/Mushroom -Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository -[http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, +Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository +[http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information and Computer Science. } \keyword{datasets} diff --git a/R-package/man/agaricus.train.Rd b/R-package/man/agaricus.train.Rd index 0c08e8080de1..6df609699dd9 100644 --- a/R-package/man/agaricus.train.Rd +++ b/R-package/man/agaricus.train.Rd @@ -4,7 +4,7 @@ \name{agaricus.train} \alias{agaricus.train} \title{Training part from Mushroom Data Set} -\format{A list containing a label vector, and a dgCMatrix object with 6513 +\format{A list containing a label vector, and a dgCMatrix object with 6513 rows and 127 variables} \usage{ data(agaricus.train) @@ -24,8 +24,8 @@ This data set includes the following fields: \references{ https://archive.ics.uci.edu/ml/datasets/Mushroom -Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository -[http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, +Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository +[http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information and Computer Science. } \keyword{datasets} diff --git a/R-package/man/predict.xgb.Booster.Rd b/R-package/man/predict.xgb.Booster.Rd index 69b48cd15bba..6430eabf5c63 100644 --- a/R-package/man/predict.xgb.Booster.Rd +++ b/R-package/man/predict.xgb.Booster.Rd @@ -49,6 +49,9 @@ It will use all the trees by default (\code{NULL} value).} prediction outputs per case. This option has no effect when either of predleaf, predcontrib, or predinteraction flags is TRUE.} +\item{training}{whether is the prediction result used for training. For dart booster, +training predicting will perform dropout.} + \item{...}{Parameters passed to \code{predict.xgb.Booster}} } \value{ diff --git a/R-package/tests/testthat/test_custom_objective.R b/R-package/tests/testthat/test_custom_objective.R index 79d8eccf8795..5e40a9b8a8b1 100644 --- a/R-package/tests/testthat/test_custom_objective.R +++ b/R-package/tests/testthat/test_custom_objective.R @@ -31,7 +31,6 @@ num_round <- 2 test_that("custom objective works", { bst <- xgb.train(param, dtrain, num_round, watchlist) expect_equal(class(bst), "xgb.Booster") - expect_equal(length(bst$raw), 1100) expect_false(is.null(bst$evaluation_log)) expect_false(is.null(bst$evaluation_log$eval_error)) expect_lt(bst$evaluation_log[num_round, eval_error], 0.03) @@ -58,5 +57,4 @@ test_that("custom objective using DMatrix attr works", { param$objective = logregobjattr bst <- xgb.train(param, dtrain, num_round, watchlist) expect_equal(class(bst), "xgb.Booster") - expect_equal(length(bst$raw), 1100) }) diff --git a/doc/python/convert_090to100.py b/doc/python/convert_090to100.py new file mode 100644 index 000000000000..135489b09d36 --- /dev/null +++ b/doc/python/convert_090to100.py @@ -0,0 +1,79 @@ +'''This is a simple script that converts a pickled XGBoost +Scikit-Learn interface object from 0.90 to a native model. Pickle +format is not stable as it's a direct serialization of Python object. +We advice not to use it when stability is needed. + +''' +import pickle +import json +import os +import argparse +import numpy as np +import xgboost +import warnings + + +def save_label_encoder(le): + '''Save the label encoder in XGBClassifier''' + meta = dict() + for k, v in le.__dict__.items(): + if isinstance(v, np.ndarray): + meta[k] = v.tolist() + else: + meta[k] = v + return meta + + +def xgboost_skl_90to100(skl_model): + '''Extract the model and related metadata in SKL model.''' + model = {} + with open(skl_model, 'rb') as fd: + old = pickle.load(fd) + if not isinstance(old, xgboost.XGBModel): + raise TypeError( + 'The script only handes Scikit-Learn interface object') + + # Save Scikit-Learn specific Python attributes into a JSON document. + for k, v in old.__dict__.items(): + if k == '_le': + model[k] = save_label_encoder(v) + elif k == 'classes_': + model[k] = v.tolist() + elif k == '_Booster': + continue + else: + try: + json.dumps({k: v}) + model[k] = v + except TypeError: + warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.') + booster = old.get_booster() + # Store the JSON serialization as an attribute + booster.set_attr(scikit_learn=json.dumps(model)) + + # Save it into a native model. + i = 0 + while True: + path = 'xgboost_native_model_from_' + skl_model + '-' + str(i) + '.bin' + if os.path.exists(path): + i += 1 + continue + booster.save_model(path) + break + + +if __name__ == '__main__': + assert xgboost.__version__ != '1.0.0', ('Please use the XGBoost version' + ' that generates this pickle.') + parser = argparse.ArgumentParser( + description=('A simple script to convert pickle generated by' + ' XGBoost 0.90 to XGBoost 1.0.0 model (not pickle).')) + parser.add_argument( + '--old-pickle', + type=str, + help='Path to old pickle file of Scikit-Learn interface object. ' + 'Will output a native model converted from this pickle file', + required=True) + args = parser.parse_args() + + xgboost_skl_90to100(args.old_pickle) diff --git a/doc/tutorials/saving_model.rst b/doc/tutorials/saving_model.rst index aa3b41e6b598..7d416ccb1bbc 100644 --- a/doc/tutorials/saving_model.rst +++ b/doc/tutorials/saving_model.rst @@ -91,7 +91,12 @@ Loading pickled file from different version of XGBoost As noted, pickled model is neither portable nor stable, but in some cases the pickled models are valuable. One way to restore it in the future is to load it back with that -specific version of Python and XGBoost, export the model by calling `save_model`. +specific version of Python and XGBoost, export the model by calling `save_model`. To help +easing the mitigation, we created a simple script for converting pickled XGBoost 0.90 +Scikit-Learn interface object to XGBoost 1.0.0 native model. Please note that the script +suits simple use cases, and it's advised not to use pickle when stability is needed. +It's located in ``xgboost/doc/python`` with the name ``convert_090to100.py``. See +comments in the script for more details. ******************************************************** Saving and Loading the internal parameters configuration diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index 7dd0a2f2d448..dbfbc1e34643 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -208,6 +208,8 @@ struct LearnerModelParam { // As the old `LearnerModelParamLegacy` is still used by binary IO, we keep // this one as an immutable copy. LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin); + /* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */ + bool Initialized() const { return num_feature != 0; } }; } // namespace xgboost diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 3b5c8ff594f9..fde1b1657663 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -600,6 +600,7 @@ def fit(self, results = train(self.client, params, dtrain, num_boost_round=self.get_num_boosting_rounds(), evals=evals) + # pylint: disable=attribute-defined-outside-init self._Booster = results['booster'] # pylint: disable=attribute-defined-outside-init self.evals_result_ = results['history'] diff --git a/src/common/json.cc b/src/common/json.cc index ecdcce3d3623..52878bbf9dcc 100644 --- a/src/common/json.cc +++ b/src/common/json.cc @@ -24,7 +24,7 @@ void JsonWriter::Visit(JsonArray const* arr) { for (size_t i = 0; i < size; ++i) { auto const& value = vec[i]; this->Save(value); - if (i != size-1) { Write(", "); } + if (i != size-1) { Write(","); } } this->Write("]"); } @@ -38,7 +38,7 @@ void JsonWriter::Visit(JsonObject const* obj) { size_t size = obj->getObject().size(); for (auto& value : obj->getObject()) { - this->Write("\"" + value.first + "\": "); + this->Write("\"" + value.first + "\":"); this->Save(value.second); if (i != size-1) { diff --git a/src/learner.cc b/src/learner.cc index 10b7882c6900..753e10844429 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -67,19 +67,26 @@ struct LearnerModelParamLegacy : public dmlc::Parameter /* \brief global bias */ bst_float base_score; /* \brief number of features */ - unsigned num_feature; + uint32_t num_feature; /* \brief number of classes, if it is multi-class classification */ - int num_class; + int32_t num_class; /*! \brief Model contain additional properties */ - int contain_extra_attrs; + int32_t contain_extra_attrs; /*! \brief Model contain eval metrics */ - int contain_eval_metrics; + int32_t contain_eval_metrics; + /*! \brief the version of XGBoost. */ + uint32_t major_version; + uint32_t minor_version; /*! \brief reserved field */ - int reserved[29]; + int reserved[27]; /*! \brief constructor */ LearnerModelParamLegacy() { std::memset(this, 0, sizeof(LearnerModelParamLegacy)); base_score = 0.5f; + major_version = std::get<0>(Version::Self()); + minor_version = std::get<1>(Version::Self()); + static_assert(sizeof(LearnerModelParamLegacy) == 136, + "Do not change the size of this struct, as it will break binary IO."); } // Skip other legacy fields. Json ToJson() const { @@ -117,8 +124,9 @@ LearnerModelParam::LearnerModelParam( LearnerModelParamLegacy const &user_param, float base_margin) : base_score{base_margin}, num_feature{user_param.num_feature}, num_output_group{user_param.num_class == 0 - ? 1 - : static_cast(user_param.num_class)} {} + ? 1 + : static_cast(user_param.num_class)} +{} struct LearnerTrainParam : public XGBoostParameter { // data split mode, can be row, col, or none. @@ -140,7 +148,7 @@ struct LearnerTrainParam : public XGBoostParameter { .describe("Data split mode for distributed training."); DMLC_DECLARE_FIELD(disable_default_eval_metric) .set_default(0) - .describe("flag to disable default metric. Set to >0 to disable"); + .describe("Flag to disable default metric. Set to >0 to disable"); DMLC_DECLARE_FIELD(booster) .set_default("gbtree") .describe("Gradient booster used for training."); @@ -200,6 +208,7 @@ class LearnerImpl : public Learner { Args args = {cfg_.cbegin(), cfg_.cend()}; tparam_.UpdateAllowUnknown(args); + auto mparam_backup = mparam_; mparam_.UpdateAllowUnknown(args); generic_parameters_.UpdateAllowUnknown(args); generic_parameters_.CheckDeprecated(); @@ -217,17 +226,33 @@ class LearnerImpl : public Learner { // set seed only before the model is initialized common::GlobalRandom().seed(generic_parameters_.seed); + // must precede configure gbm since num_features is required for gbm this->ConfigureNumFeatures(); args = {cfg_.cbegin(), cfg_.cend()}; // renew this->ConfigureObjective(old_tparam, &args); - this->ConfigureGBM(old_tparam, args); - this->ConfigureMetrics(args); + // Before 1.0.0, we save `base_score` into binary as a transformed value by objective. + // After 1.0.0 we save the value provided by user and keep it immutable instead. To + // keep the stability, we initialize it in binary LoadModel instead of configuration. + // Under what condition should we omit the transformation: + // + // - base_score is loaded from old binary model. + // + // What are the other possible conditions: + // + // - model loaded from new binary or JSON. + // - model is created from scratch. + // - model is configured second time due to change of parameter + if (!learner_model_param_.Initialized() || mparam_.base_score != mparam_backup.base_score) { + learner_model_param_ = LearnerModelParam(mparam_, + obj_->ProbToMargin(mparam_.base_score)); + } + + this->ConfigureGBM(old_tparam, args); generic_parameters_.ConfigureGpuId(this->gbm_->UseGPU()); - learner_model_param_ = LearnerModelParam(mparam_, - obj_->ProbToMargin(mparam_.base_score)); + this->ConfigureMetrics(args); this->need_configuration_ = false; if (generic_parameters_.validate_parameters) { @@ -337,9 +362,6 @@ class LearnerImpl : public Learner { cache_)); gbm_->LoadModel(gradient_booster); - learner_model_param_ = LearnerModelParam(mparam_, - obj_->ProbToMargin(mparam_.base_score)); - auto const& j_attributes = get(learner.at("attributes")); attributes_.clear(); for (auto const& kv : j_attributes) { @@ -459,6 +481,7 @@ class LearnerImpl : public Learner { } if (header[0] == '{') { + // Dispatch to JSON auto json_stream = common::FixedSizeStream(&fp); std::string buffer; json_stream.Take(&buffer); @@ -471,25 +494,10 @@ class LearnerImpl : public Learner { // read parameter CHECK_EQ(fi->Read(&mparam_, sizeof(mparam_)), sizeof(mparam_)) << "BoostLearner: wrong model format"; - { - // backward compatibility code for compatible with old model type - // for new model, Read(&name_obj_) is suffice - uint64_t len; - CHECK_EQ(fi->Read(&len, sizeof(len)), sizeof(len)); - if (len >= std::numeric_limits::max()) { - int gap; - CHECK_EQ(fi->Read(&gap, sizeof(gap)), sizeof(gap)) - << "BoostLearner: wrong model format"; - len = len >> static_cast(32UL); - } - if (len != 0) { - tparam_.objective.resize(len); - CHECK_EQ(fi->Read(&tparam_.objective[0], len), len) - << "BoostLearner: wrong model format"; - } - } + + CHECK(fi->Read(&tparam_.objective)) << "BoostLearner: wrong model format"; CHECK(fi->Read(&tparam_.booster)) << "BoostLearner: wrong model format"; - // duplicated code with LazyInitModel + obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_)); gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_parameters_, &learner_model_param_, cache_)); @@ -508,34 +516,57 @@ class LearnerImpl : public Learner { } attributes_ = std::map(attr.begin(), attr.end()); } - if (tparam_.objective == "count:poisson") { - std::string max_delta_step; - fi->Read(&max_delta_step); - cfg_["max_delta_step"] = max_delta_step; + bool warn_old_model { false }; + if (attributes_.find("count_poisson_max_delta_step") != attributes_.cend()) { + // Loading model from < 1.0.0, objective is not saved. + cfg_["max_delta_step"] = attributes_.at("count_poisson_max_delta_step"); + attributes_.erase("count_poisson_max_delta_step"); + warn_old_model = true; + } else { + warn_old_model = false; + } + + if (mparam_.major_version >= 1) { + learner_model_param_ = LearnerModelParam(mparam_, + obj_->ProbToMargin(mparam_.base_score)); + } else { + // Before 1.0.0, base_score is saved as a transformed value, and there's no version + // attribute in the saved model. + learner_model_param_ = LearnerModelParam(mparam_, mparam_.base_score); + warn_old_model = true; } - if (mparam_.contain_eval_metrics != 0) { - std::vector metr; - fi->Read(&metr); - for (auto name : metr) { - metrics_.emplace_back(Metric::Create(name, &generic_parameters_)); + if (attributes_.find("objective") != attributes_.cend()) { + auto obj_str = attributes_.at("objective"); + auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()}); + obj_->LoadConfig(j_obj); + attributes_.erase("objective"); + } else { + warn_old_model = true; + } + if (attributes_.find("metrics") != attributes_.cend()) { + auto metrics_str = attributes_.at("metrics"); + std::vector names { common::Split(metrics_str, ';') }; + attributes_.erase("metrics"); + for (auto const& n : names) { + this->SetParam(kEvalMetric, n); } } + if (warn_old_model) { + LOG(WARNING) << "Loading model from XGBoost < 1.0.0, consider saving it " + "again for improved compatibility"; + } + + // Renew the version. + mparam_.major_version = std::get<0>(Version::Self()); + mparam_.minor_version = std::get<1>(Version::Self()); + cfg_["num_class"] = common::ToString(mparam_.num_class); cfg_["num_feature"] = common::ToString(mparam_.num_feature); auto n = tparam_.__DICT__(); cfg_.insert(n.cbegin(), n.cend()); - Args args = {cfg_.cbegin(), cfg_.cend()}; - generic_parameters_.UpdateAllowUnknown(args); - gbm_->Configure(args); - obj_->Configure({cfg_.begin(), cfg_.end()}); - - for (auto& p_metric : metrics_) { - p_metric->Configure({cfg_.begin(), cfg_.end()}); - } - // copy dsplit from config since it will not run again during restore if (tparam_.dsplit == DataSplitMode::kAuto && rabit::IsDistributed()) { tparam_.dsplit = DataSplitMode::kRow; @@ -552,15 +583,8 @@ class LearnerImpl : public Learner { void SaveModel(dmlc::Stream* fo) const override { LearnerModelParamLegacy mparam = mparam_; // make a copy to potentially modify std::vector > extra_attr; - // extra attributed to be added just before saving - if (tparam_.objective == "count:poisson") { - auto it = cfg_.find("max_delta_step"); - if (it != cfg_.end()) { - // write `max_delta_step` parameter as extra attribute of booster - mparam.contain_extra_attrs = 1; - extra_attr.emplace_back("count_poisson_max_delta_step", it->second); - } - } + mparam.contain_extra_attrs = 1; + { std::vector saved_params; // check if rabit_bootstrap_cache were set to non zero before adding to checkpoint @@ -577,6 +601,24 @@ class LearnerImpl : public Learner { } } } + { + // Similar to JSON model IO, we save the objective. + Json j_obj { Object() }; + obj_->SaveConfig(&j_obj); + std::string obj_doc; + Json::Dump(j_obj, &obj_doc); + extra_attr.emplace_back("objective", obj_doc); + } + // As of 1.0.0, JVM Package and R Package uses Save/Load model for serialization. + // Remove this part once they are ported to use actual serialization methods. + if (mparam.contain_eval_metrics != 0) { + std::stringstream os; + for (auto& ev : metrics_) { + os << ev->Name() << ";"; + } + extra_attr.emplace_back("metrics", os.str()); + } + fo->Write(&mparam, sizeof(LearnerModelParamLegacy)); fo->Write(tparam_.objective); fo->Write(tparam_.booster); @@ -587,26 +629,7 @@ class LearnerImpl : public Learner { attr[kv.first] = kv.second; } fo->Write(std::vector>( - attr.begin(), attr.end())); - } - if (tparam_.objective == "count:poisson") { - auto it = cfg_.find("max_delta_step"); - if (it != cfg_.end()) { - fo->Write(it->second); - } else { - // recover value of max_delta_step from extra attributes - auto it2 = attributes_.find("count_poisson_max_delta_step"); - const std::string max_delta_step - = (it2 != attributes_.end()) ? it2->second : kMaxDeltaStepDefaultValue; - fo->Write(max_delta_step); - } - } - if (mparam.contain_eval_metrics != 0) { - std::vector metr; - for (auto& ev : metrics_) { - metr.emplace_back(ev->Name()); - } - fo->Write(metr); + attr.begin(), attr.end())); } } @@ -661,11 +684,13 @@ class LearnerImpl : public Learner { If you are loading a serialized model (like pickle in Python) generated by older XGBoost, please export the model by calling `Booster.save_model` from that version - first, then load it back in current version. See: + first, then load it back in current version. There's a simple script for helping + the process. See: https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html - for more details about differences between saving model and serializing. + for reference to the script, and more details about differences between saving model and + serializing. )doc"; int64_t sz {-1}; @@ -854,7 +879,8 @@ class LearnerImpl : public Learner { void ConfigureObjective(LearnerTrainParam const& old, Args* p_args) { // Once binary IO is gone, NONE of these config is useful. - if (cfg_.find("num_class") != cfg_.cend() && cfg_.at("num_class") != "0") { + if (cfg_.find("num_class") != cfg_.cend() && cfg_.at("num_class") != "0" && + tparam_.objective != "multi:softprob") { cfg_["num_output_group"] = cfg_["num_class"]; if (atoi(cfg_["num_class"].c_str()) > 1 && cfg_.count("objective") == 0) { tparam_.objective = "multi:softmax"; @@ -919,7 +945,6 @@ class LearnerImpl : public Learner { } CHECK_NE(mparam_.num_feature, 0) << "0 feature is supplied. Are you using raw Booster interface?"; - learner_model_param_.num_feature = mparam_.num_feature; // Remove these once binary IO is gone. cfg_["num_feature"] = common::ToString(mparam_.num_feature); cfg_["num_class"] = common::ToString(mparam_.num_class); diff --git a/tests/ci_build/Dockerfile.cpu b/tests/ci_build/Dockerfile.cpu index 7f85010466d0..39b209c8e4b1 100644 --- a/tests/ci_build/Dockerfile.cpu +++ b/tests/ci_build/Dockerfile.cpu @@ -21,8 +21,9 @@ ENV GOSU_VERSION 1.10 # Install Python packages RUN \ - pip install pyyaml cpplint pylint astroid sphinx numpy scipy pandas matplotlib sh recommonmark guzzle_sphinx_theme mock \ - breathe matplotlib graphviz pytest scikit-learn wheel kubernetes urllib3 jsonschema && \ + pip install pyyaml cpplint pylint astroid sphinx numpy scipy pandas matplotlib sh \ + recommonmark guzzle_sphinx_theme mock breathe matplotlib graphviz \ + pytest scikit-learn wheel kubernetes urllib3 jsonschema boto3 && \ pip install https://h2o-release.s3.amazonaws.com/datatable/stable/datatable-0.7.0/datatable-0.7.0-cp37-cp37m-linux_x86_64.whl && \ pip install "dask[complete]" diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index d21c634a93c5..111a75028a5c 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -183,6 +183,41 @@ TEST(Learner, JsonModelIO) { delete pp_dmat; } +TEST(Learner, BinaryModelIO) { + size_t constexpr kRows = 8; + int32_t constexpr kIters = 4; + auto pp_dmat = CreateDMatrix(kRows, 10, 0); + std::shared_ptr p_dmat {*pp_dmat}; + p_dmat->Info().labels_.Resize(kRows); + + std::unique_ptr learner{Learner::Create({p_dmat})}; + learner->SetParam("eval_metric", "rmsle"); + learner->Configure(); + for (int32_t iter = 0; iter < kIters; ++iter) { + learner->UpdateOneIter(iter, p_dmat); + } + dmlc::TemporaryDirectory tempdir; + std::string const fname = tempdir.path + "binary_model_io.bin"; + { + // Make sure the write is complete before loading. + std::unique_ptr fo(dmlc::Stream::Create(fname.c_str(), "w")); + learner->SaveModel(fo.get()); + } + + learner.reset(Learner::Create({p_dmat})); + std::unique_ptr fi(dmlc::Stream::Create(fname.c_str(), "r")); + learner->LoadModel(fi.get()); + learner->Configure(); + Json config { Object() }; + learner->SaveConfig(&config); + std::string config_str; + Json::Dump(config, &config_str); + ASSERT_NE(config_str.find("rmsle"), std::string::npos); + ASSERT_EQ(config_str.find("WARNING"), std::string::npos); + + delete pp_dmat; +} + #if defined(XGBOOST_USE_CUDA) // Tests for automatic GPU configuration. TEST(Learner, GPUConfiguration) { diff --git a/tests/python/generate_models.py b/tests/python/generate_models.py new file mode 100644 index 000000000000..6376d802e659 --- /dev/null +++ b/tests/python/generate_models.py @@ -0,0 +1,148 @@ +import xgboost +import numpy as np +import os + +kRounds = 2 +kRows = 1000 +kCols = 4 +kForests = 2 +kMaxDepth = 2 +kClasses = 3 + +X = np.random.randn(kRows, kCols) +w = np.random.uniform(size=kRows) + +version = xgboost.__version__ + +np.random.seed(1994) +target_dir = 'models' + + +def booster_bin(model): + return os.path.join(target_dir, + 'xgboost-' + version + '.' + model + '.bin') + + +def booster_json(model): + return os.path.join(target_dir, + 'xgboost-' + version + '.' + model + '.json') + + +def skl_bin(model): + return os.path.join(target_dir, + 'xgboost_scikit-' + version + '.' + model + '.bin') + + +def skl_json(model): + return os.path.join(target_dir, + 'xgboost_scikit-' + version + '.' + model + '.json') + + +def generate_regression_model(): + print('Regression') + y = np.random.randn(kRows) + + data = xgboost.DMatrix(X, label=y, weight=w) + booster = xgboost.train({'tree_method': 'hist', + 'num_parallel_tree': kForests, + 'max_depth': kMaxDepth}, + num_boost_round=kRounds, dtrain=data) + booster.save_model(booster_bin('reg')) + booster.save_model(booster_json('reg')) + + reg = xgboost.XGBRegressor(tree_method='hist', + num_parallel_tree=kForests, + max_depth=kMaxDepth, + n_estimators=kRounds) + reg.fit(X, y, w) + reg.save_model(skl_bin('reg')) + reg.save_model(skl_json('reg')) + + +def generate_logistic_model(): + print('Logistic') + y = np.random.randint(0, 2, size=kRows) + assert y.max() == 1 and y.min() == 0 + + data = xgboost.DMatrix(X, label=y, weight=w) + booster = xgboost.train({'tree_method': 'hist', + 'num_parallel_tree': kForests, + 'max_depth': kMaxDepth, + 'objective': 'binary:logistic'}, + num_boost_round=kRounds, dtrain=data) + booster.save_model(booster_bin('logit')) + booster.save_model(booster_json('logit')) + + reg = xgboost.XGBClassifier(tree_method='hist', + num_parallel_tree=kForests, + max_depth=kMaxDepth, + n_estimators=kRounds) + reg.fit(X, y, w) + reg.save_model(skl_bin('logit')) + reg.save_model(skl_json('logit')) + + +def generate_classification_model(): + print('Classification') + y = np.random.randint(0, kClasses, size=kRows) + data = xgboost.DMatrix(X, label=y, weight=w) + booster = xgboost.train({'num_class': kClasses, + 'tree_method': 'hist', + 'num_parallel_tree': kForests, + 'max_depth': kMaxDepth}, + num_boost_round=kRounds, dtrain=data) + booster.save_model(booster_bin('cls')) + booster.save_model(booster_json('cls')) + + cls = xgboost.XGBClassifier(tree_method='hist', + num_parallel_tree=kForests, + max_depth=kMaxDepth, + n_estimators=kRounds) + cls.fit(X, y, w) + cls.save_model(skl_bin('cls')) + cls.save_model(skl_json('cls')) + + +def generate_ranking_model(): + print('Learning to Rank') + y = np.random.randint(5, size=kRows) + w = np.random.uniform(size=20) + g = np.repeat(50, 20) + + data = xgboost.DMatrix(X, y, weight=w) + data.set_group(g) + booster = xgboost.train({'objective': 'rank:ndcg', + 'num_parallel_tree': kForests, + 'tree_method': 'hist', + 'max_depth': kMaxDepth}, + num_boost_round=kRounds, + dtrain=data) + booster.save_model(booster_bin('ltr')) + booster.save_model(booster_json('ltr')) + + ranker = xgboost.sklearn.XGBRanker(n_estimators=kRounds, + tree_method='hist', + objective='rank:ndcg', + max_depth=kMaxDepth, + num_parallel_tree=kForests) + ranker.fit(X, y, g, sample_weight=w) + ranker.save_model(skl_bin('ltr')) + ranker.save_model(skl_json('ltr')) + + +def write_versions(): + versions = {'numpy': np.__version__, + 'xgboost': version} + with open(os.path.join(target_dir, 'version'), 'w') as fd: + fd.write(str(versions)) + + +if __name__ == '__main__': + if not os.path.exists(target_dir): + os.mkdir(target_dir) + + generate_regression_model() + generate_logistic_model() + generate_classification_model() + generate_ranking_model() + write_versions() diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index d27c50d71d5c..2c3237000a52 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -39,7 +39,7 @@ class TestBasic(unittest.TestCase): def test_basic(self): dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train') dtest = xgb.DMatrix(dpath + 'agaricus.txt.test') - param = {'max_depth': 2, 'eta': 1, 'verbosity': 0, + param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'} # specify validations set to watch performance watchlist = [(dtest, 'eval'), (dtrain, 'train')] diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index a72071b20d52..a5eb395dff27 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -284,16 +284,31 @@ def test_feature_names_validation(self): self.assertRaises(ValueError, bst.predict, dm1) bst.predict(dm2) # success + def test_model_binary_io(self): + model_path = 'test_model_binary_io.bin' + parameters = {'tree_method': 'hist', 'booster': 'gbtree', + 'scale_pos_weight': '0.5'} + X = np.random.random((10, 3)) + y = np.random.random((10,)) + dtrain = xgb.DMatrix(X, y) + bst = xgb.train(parameters, dtrain, num_boost_round=2) + bst.save_model(model_path) + bst = xgb.Booster(model_file=model_path) + os.remove(model_path) + config = json.loads(bst.save_config()) + assert float(config['learner']['objective'][ + 'reg_loss_param']['scale_pos_weight']) == 0.5 + def test_model_json_io(self): - model_path = './model.json' + model_path = 'test_model_json_io.json' parameters = {'tree_method': 'hist', 'booster': 'gbtree'} j_model = json_model(model_path, parameters) assert isinstance(j_model['learner'], dict) - bst = xgb.Booster(model_file='./model.json') + bst = xgb.Booster(model_file=model_path) bst.save_model(fname=model_path) - with open('./model.json', 'r') as fd: + with open(model_path, 'r') as fd: j_model = json.load(fd) assert isinstance(j_model['learner'], dict) @@ -302,7 +317,7 @@ def test_model_json_io(self): @pytest.mark.skipif(**tm.no_json_schema()) def test_json_schema(self): import jsonschema - model_path = './model.json' + model_path = 'test_json_schema.json' path = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) doc = os.path.join(path, 'doc', 'model.schema') diff --git a/tests/python/test_model_compatibility.py b/tests/python/test_model_compatibility.py new file mode 100644 index 000000000000..3ab85c74be8a --- /dev/null +++ b/tests/python/test_model_compatibility.py @@ -0,0 +1,130 @@ +import xgboost +import os +import generate_models as gm +import json +import zipfile +import pytest + + +def run_model_param_check(config): + assert config['learner']['learner_model_param']['num_feature'] == str(4) + assert config['learner']['learner_train_param']['booster'] == 'gbtree' + + +def run_booster_check(booster, name): + config = json.loads(booster.save_config()) + run_model_param_check(config) + if name.find('cls') != -1: + assert (len(booster.get_dump()) == gm.kForests * gm.kRounds * + gm.kClasses) + assert float( + config['learner']['learner_model_param']['base_score']) == 0.5 + assert config['learner']['learner_train_param'][ + 'objective'] == 'multi:softmax' + elif name.find('logit') != -1: + assert len(booster.get_dump()) == gm.kForests * gm.kRounds + assert config['learner']['learner_model_param']['num_class'] == str(0) + assert config['learner']['learner_train_param'][ + 'objective'] == 'binary:logistic' + elif name.find('ltr') != -1: + assert config['learner']['learner_train_param'][ + 'objective'] == 'rank:ndcg' + else: + assert name.find('reg') != -1 + assert len(booster.get_dump()) == gm.kForests * gm.kRounds + assert float( + config['learner']['learner_model_param']['base_score']) == 0.5 + assert config['learner']['learner_train_param'][ + 'objective'] == 'reg:squarederror' + + +def run_scikit_model_check(name, path): + if name.find('reg') != -1: + reg = xgboost.XGBRegressor() + reg.load_model(path) + config = json.loads(reg.get_booster().save_config()) + if name.find('0.90') != -1: + assert config['learner']['learner_train_param'][ + 'objective'] == 'reg:linear' + else: + assert config['learner']['learner_train_param'][ + 'objective'] == 'reg:squarederror' + assert (len(reg.get_booster().get_dump()) == + gm.kRounds * gm.kForests) + run_model_param_check(config) + elif name.find('cls') != -1: + cls = xgboost.XGBClassifier() + cls.load_model(path) + if name.find('0.90') == -1: + assert len(cls.classes_) == gm.kClasses + assert len(cls._le.classes_) == gm.kClasses + assert cls.n_classes_ == gm.kClasses + assert (len(cls.get_booster().get_dump()) == + gm.kRounds * gm.kForests * gm.kClasses), path + config = json.loads(cls.get_booster().save_config()) + assert config['learner']['learner_train_param'][ + 'objective'] == 'multi:softprob', path + run_model_param_check(config) + elif name.find('ltr') != -1: + ltr = xgboost.XGBRanker() + ltr.load_model(path) + assert (len(ltr.get_booster().get_dump()) == + gm.kRounds * gm.kForests) + config = json.loads(ltr.get_booster().save_config()) + assert config['learner']['learner_train_param'][ + 'objective'] == 'rank:ndcg' + run_model_param_check(config) + elif name.find('logit') != -1: + logit = xgboost.XGBClassifier() + logit.load_model(path) + assert (len(logit.get_booster().get_dump()) == + gm.kRounds * gm.kForests) + config = json.loads(logit.get_booster().save_config()) + assert config['learner']['learner_train_param'][ + 'objective'] == 'binary:logistic' + else: + assert False + + +@pytest.mark.ci +def test_model_compatibility(): + '''Test model compatibility, can only be run on CI as others don't + have the credentials. + + ''' + path = os.path.dirname(os.path.abspath(__file__)) + path = os.path.join(path, 'models') + try: + import boto3 + import botocore + except ImportError: + pytest.skip( + 'Skiping compatibility tests as boto3 is not installed.') + + try: + s3_bucket = boto3.resource('s3').Bucket('xgboost-ci-jenkins-artifacts') + zip_path = 'xgboost_model_compatibility_test.zip' + s3_bucket.download_file(zip_path, zip_path) + except botocore.exceptions.NoCredentialsError: + pytest.skip( + 'Skiping compatibility tests as running on non-CI environment.') + + with zipfile.ZipFile(zip_path, 'r') as z: + z.extractall(path) + + models = [ + os.path.join(root, f) for root, subdir, files in os.walk(path) + for f in files + if f != 'version' + ] + assert models + + for path in models: + name = os.path.basename(path) + if name.startswith('xgboost-'): + booster = xgboost.Booster(model_file=path) + run_booster_check(booster, name) + elif name.startswith('xgboost_scikit'): + run_scikit_model_check(name, path) + else: + assert False diff --git a/tests/python/test_ranking.py b/tests/python/test_ranking.py index 23a5180737f6..39de5d5cac55 100644 --- a/tests/python/test_ranking.py +++ b/tests/python/test_ranking.py @@ -115,7 +115,6 @@ def setUpClass(cls): # model training parameters cls.params = {'objective': 'rank:pairwise', 'booster': 'gbtree', - 'silent': 0, 'eval_metric': ['ndcg'] } @@ -143,7 +142,7 @@ def test_cv(self): Test cross-validation with a group specified """ cv = xgboost.cv(self.params, self.dtrain, num_boost_round=2500, - early_stopping_rounds=10, nfold=10, as_pandas=False) + early_stopping_rounds=10, nfold=10, as_pandas=False) assert isinstance(cv, dict) self.assertSetEqual(set(cv.keys()), {'test-ndcg-mean', 'train-ndcg-mean', 'test-ndcg-std', 'train-ndcg-std'}, "CV results dict key mismatch") @@ -153,7 +152,8 @@ def test_cv_no_shuffle(self): Test cross-validation with a group specified """ cv = xgboost.cv(self.params, self.dtrain, num_boost_round=2500, - early_stopping_rounds=10, shuffle=False, nfold=10, as_pandas=False) + early_stopping_rounds=10, shuffle=False, nfold=10, + as_pandas=False) assert isinstance(cv, dict) assert len(cv) == 4