diff --git a/doc/model.schema b/doc/model.schema index d90049caa2c9..86acea967f1e 100644 --- a/doc/model.schema +++ b/doc/model.schema @@ -97,7 +97,7 @@ "default_left": { "type": "array", "items": { - "type": "boolean" + "type": "integer" } }, "categories": { diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 9d657872e26b..993bde63223e 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -1081,14 +1081,32 @@ XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle, const void *buf, bst_ulong len); + /*! - * \brief save model into binary raw bytes, return header of the array - * user must copy the result out, before next xgboost call + * \brief Save model into raw bytes, return header of the array. User must copy the + * result out, before next xgboost call + * * \param handle handle - * \param out_len the argument to hold the output length - * \param out_dptr the argument to hold the output data pointer + * \param json_config JSON encoded string storing parameters for the function. Following + * keys are expected in the JSON document: + * + * "format": str + * - json: Output booster will be encoded as JSON. + * - ubj: Output booster will be encoded as Univeral binary JSON. + * - deprecated: Output booster will be encoded as old custom binary format. Do not use + * this format except for compatibility reasons. + * + * \param out_len The argument to hold the output length + * \param out_dptr The argument to hold the output data pointer + * * \return 0 when success, -1 when failure happens */ +XGB_DLL int XGBoosterSaveModelToBuffer(BoosterHandle handle, char const *json_config, + bst_ulong *out_len, char const **out_dptr); + +/*! + * \brief Deprecated, use `XGBoosterSaveModelToBuffer` instead. + */ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle, bst_ulong *out_len, const char **out_dptr); diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 2d1ed7f125b4..23859ac22b0d 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -1,5 +1,5 @@ /*! - * Copyright (c) 2015-2021 by Contributors + * Copyright (c) 2015-2022 by Contributors * \file data.h * \brief The input data structure of xgboost. * \author Tianqi Chen @@ -36,10 +36,7 @@ enum class DataType : uint8_t { kStr = 5 }; -enum class FeatureType : uint8_t { - kNumerical, - kCategorical -}; +enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 }; /*! * \brief Meta information about dataset, always sit in memory. diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h index 78a18c044f61..897a7330189d 100644 --- a/include/xgboost/linalg.h +++ b/include/xgboost/linalg.h @@ -1,5 +1,5 @@ /*! - * Copyright 2021 by XGBoost Contributors + * Copyright 2021-2022 by XGBoost Contributors * \file linalg.h * \brief Linear algebra related utilities. */ @@ -567,7 +567,7 @@ template Json ArrayInterface(TensorView const &t) { Json array_interface{Object{}}; array_interface["data"] = std::vector(2); - array_interface["data"][0] = Integer(reinterpret_cast(t.Values().data())); + array_interface["data"][0] = Integer{reinterpret_cast(t.Values().data())}; array_interface["data"][1] = Boolean{true}; if (t.DeviceIdx() >= 0) { // Change this once we have different CUDA stream. diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 670bef938e9c..5c23f278a83b 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -1,5 +1,5 @@ /*! - * Copyright 2014-2019 by Contributors + * Copyright 2014-2022 by Contributors * \file tree_model.h * \brief model structure for tree * \author Tianqi Chen @@ -42,7 +42,7 @@ struct TreeParam : public dmlc::Parameter { /*! \brief maximum depth, this is a statistics of the tree */ int deprecated_max_depth; /*! \brief number of features used for tree construction */ - int num_feature; + bst_feature_t num_feature; /*! * \brief leaf vector size, used for vector tree * used to store more than one dimensional information in tree @@ -629,6 +629,7 @@ class RegTree : public Model { } private: + template void LoadCategoricalSplit(Json const& in); void SaveCategoricalSplit(Json* p_out) const; // vector of nodes diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 98a2759c88c5..d469c2ab957f 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2014-2021 by Contributors +// Copyright (c) 2014-2022 by Contributors #include #include @@ -248,22 +248,16 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data, #endif // Create from data iterator -XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, - DMatrixHandle proxy, - DataIterResetCallback *reset, - XGDMatrixCallbackNext *next, - char const* c_json_config, - DMatrixHandle *out) { +XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy, + DataIterResetCallback *reset, XGDMatrixCallbackNext *next, + char const *c_json_config, DMatrixHandle *out) { API_BEGIN(); auto config = Json::Load(StringView{c_json_config}); - float missing = get(config["missing"]); - std::string cache = get(config["cache_prefix"]); - int32_t n_threads = omp_get_max_threads(); - if (!IsA(config["nthread"])) { - n_threads = get(config["nthread"]); - } - *out = new std::shared_ptr{xgboost::DMatrix::Create( - iter, proxy, reset, next, missing, n_threads, cache)}; + auto missing = GetMissing(config); + std::string cache = RequiredArg(config, "cache_prefix", __func__); + auto n_threads = OptionalArg(config, "nthread", common::OmpGetNumThreads(0)); + *out = new std::shared_ptr{ + xgboost::DMatrix::Create(iter, proxy, reset, next, missing, n_threads, cache)}; API_END(); } @@ -358,8 +352,8 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, StringView{data}, ncol); auto config = Json::Load(StringView{c_json_config}); float missing = GetMissing(config); - auto nthread = get(config["nthread"]); - *out = new std::shared_ptr(DMatrix::Create(&adapter, missing, nthread)); + auto n_threads = OptionalArg(config, "nthread", common::OmpGetNumThreads(0)); + *out = new std::shared_ptr(DMatrix::Create(&adapter, missing, n_threads)); API_END(); } @@ -371,9 +365,9 @@ XGB_DLL int XGDMatrixCreateFromDense(char const *data, xgboost::data::ArrayAdapter(StringView{data})}; auto config = Json::Load(StringView{c_json_config}); float missing = GetMissing(config); - auto nthread = get(config["nthread"]); + auto n_threads = OptionalArg(config, "nthread", common::OmpGetNumThreads(0)); *out = - new std::shared_ptr(DMatrix::Create(&adapter, missing, nthread)); + new std::shared_ptr(DMatrix::Create(&adapter, missing, n_threads)); API_END(); } @@ -765,11 +759,11 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, auto& entry = learner->GetThreadLocal().prediction_entry; auto p_m = *static_cast *>(dmat); - auto const& j_config = get(config); - auto type = PredictionType(get(j_config.at("type"))); - auto iteration_begin = get(j_config.at("iteration_begin")); - auto iteration_end = get(j_config.at("iteration_end")); + auto type = PredictionType(RequiredArg(config, "type", __func__)); + auto iteration_begin = RequiredArg(config, "iteration_begin", __func__); + auto iteration_end = RequiredArg(config, "iteration_end", __func__); + auto const& j_config = get(config); auto ntree_limit_it = j_config.find("ntree_limit"); if (ntree_limit_it != j_config.cend() && !IsA(ntree_limit_it->second) && get(ntree_limit_it->second) != 0) { @@ -785,7 +779,7 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, type == PredictionType::kApproxContribution; bool interactions = type == PredictionType::kInteraction || type == PredictionType::kApproxInteraction; - bool training = get(config["training"]); + bool training = RequiredArg(config, "training", __func__); learner->Predict(p_m, type == PredictionType::kMargin, &entry.predictions, iteration_begin, iteration_end, training, type == PredictionType::kLeaf, contribs, approximate, @@ -796,7 +790,7 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, auto rounds = iteration_end - iteration_begin; rounds = rounds == 0 ? learner->BoostedRounds() : rounds; // Determine shape - bool strict_shape = get(config["strict_shape"]); + bool strict_shape = RequiredArg(config, "strict_shape", __func__); CalcPredictShape(strict_shape, type, p_m->Info().num_row_, p_m->Info().num_col_, chunksize, learner->Groups(), rounds, &shape, out_dim); @@ -814,15 +808,15 @@ void InplacePredictImpl(std::shared_ptr x, std::shared_ptr p_m, CHECK_EQ(get(config["cache_id"]), 0) << "Cache ID is not supported yet"; HostDeviceVector* p_predt { nullptr }; - auto type = PredictionType(get(config["type"])); + auto type = PredictionType(RequiredArg(config, "type", __func__)); float missing = GetMissing(config); learner->InplacePredict(x, p_m, type, missing, &p_predt, - get(config["iteration_begin"]), - get(config["iteration_end"])); + RequiredArg(config, "iteration_begin", __func__), + RequiredArg(config, "iteration_end", __func__)); CHECK(p_predt); auto &shape = learner->GetThreadLocal().prediction_shape; auto chunksize = n_rows == 0 ? 0 : p_predt->Size() / n_rows; - bool strict_shape = get(config["strict_shape"]); + bool strict_shape = RequiredArg(config, "strict_shape", __func__); CalcPredictShape(strict_shape, type, n_rows, n_cols, chunksize, learner->Groups(), learner->BoostedRounds(), &shape, out_dim); *out_result = dmlc::BeginPtr(p_predt->HostVector()); @@ -900,12 +894,21 @@ XGB_DLL int XGBoosterPredictFromCUDAColumnar( XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) { API_BEGIN(); CHECK_HANDLE(); - if (common::FileExtension(fname) == "json") { + auto read_file = [&]() { auto str = common::LoadSequentialFile(fname); - CHECK_GT(str.size(), 2); + CHECK_GE(str.size(), 3); // "{}\0" CHECK_EQ(str[0], '{'); - Json in { Json::Load({str.c_str(), str.size()}) }; + CHECK_EQ(str[str.size() - 2], '}'); + return str; + }; + if (common::FileExtension(fname) == "json") { + auto str = read_file(); + Json in{Json::Load(StringView{str})}; static_cast(handle)->LoadModel(in); + } else if (common::FileExtension(fname) == "ubj") { + auto str = read_file(); + Json in = Json::Load(StringView{str}, std::ios::binary); + static_cast(handle)->LoadModel(in); } else { std::unique_ptr fi(dmlc::Stream::Create(fname, "r")); static_cast(handle)->LoadModel(fi.get()); @@ -913,32 +916,83 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) { API_END(); } -XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char* c_fname) { +namespace { +void WarnOldModel() { + if (XGBOOST_VER_MAJOR >= 2) { + LOG(WARNING) << "Saving into deprecated binary model format, please consider using `json` or " + "`ubj`. Model format will default to JSON in XGBoost 2.2 if not specified."; + } +} +} // anonymous namespace + +XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char *c_fname) { API_BEGIN(); CHECK_HANDLE(); std::unique_ptr fo(dmlc::Stream::Create(c_fname, "w")); auto *learner = static_cast(handle); learner->Configure(); - if (common::FileExtension(c_fname) == "json") { - Json out { Object() }; + auto save_json = [&](std::ios::openmode mode) { + Json out{Object()}; learner->SaveModel(&out); - std::string str; - Json::Dump(out, &str); - fo->Write(str.c_str(), str.size()); + std::vector str; + Json::Dump(out, &str, mode); + fo->Write(str.data(), str.size()); + }; + if (common::FileExtension(c_fname) == "json") { + save_json(std::ios::out); + } else if (common::FileExtension(c_fname) == "ubj") { + save_json(std::ios::binary); + } else if (XGBOOST_VER_MAJOR == 2 && XGBOOST_VER_MINOR >= 2) { + LOG(WARNING) << "Saving model to JSON as default. You can use file extension `json`, `ubj` or " + "`deprecated` to choose between formats."; + save_json(std::ios::out); } else { - auto *bst = static_cast(handle); + WarnOldModel(); + auto *bst = static_cast(handle); bst->SaveModel(fo.get()); } API_END(); } -XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle, - const void* buf, +XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle, const void *buf, xgboost::bst_ulong len) { API_BEGIN(); CHECK_HANDLE(); - common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*) - static_cast(handle)->LoadModel(&fs); + common::MemoryFixSizeBuffer fs((void *)buf, len); // NOLINT(*) + static_cast(handle)->LoadModel(&fs); + API_END(); +} + +XGB_DLL int XGBoosterSaveModelToBuffer(BoosterHandle handle, char const *json_config, + xgboost::bst_ulong *out_len, char const **out_dptr) { + API_BEGIN(); + CHECK_HANDLE(); + auto config = Json::Load(StringView{json_config}); + auto format = RequiredArg(config, "format", __func__); + + auto *learner = static_cast(handle); + std::string &raw_str = learner->GetThreadLocal().ret_str; + raw_str.clear(); + + learner->Configure(); + Json out{Object{}}; + if (format == "json") { + learner->SaveModel(&out); + Json::Dump(out, &raw_str); + } else if (format == "ubj") { + learner->SaveModel(&out); + Json::Dump(out, &raw_str, std::ios::binary); + } else if (format == "deprecated") { + WarnOldModel(); + common::MemoryBufferStream fo(&raw_str); + learner->SaveModel(&fo); + } else { + LOG(FATAL) << "Unknown format: `" << format << "`"; + } + + *out_dptr = dmlc::BeginPtr(raw_str); + *out_len = static_cast(raw_str.length()); + API_END(); } @@ -952,6 +1006,8 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle, raw_str.resize(0); common::MemoryBufferStream fo(&raw_str); + LOG(WARNING) << "`" << __func__ + << "` is deprecated, please use `XGBoosterSaveModelToBuffer` instead."; learner->Configure(); learner->SaveModel(&fo); @@ -1208,7 +1264,8 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config, CHECK_HANDLE(); auto *learner = static_cast(handle); auto config = Json::Load(StringView{json_config}); - auto importance = get(config["importance_type"]); + + auto importance = RequiredArg(config, "importance_type", __func__); std::string feature_map_uri; if (!IsA(config["feature_map"])) { feature_map_uri = get(config["feature_map"]); diff --git a/src/c_api/c_api_utils.h b/src/c_api/c_api_utils.h index f29c302a1686..1a004bcd2f0f 100644 --- a/src/c_api/c_api_utils.h +++ b/src/c_api/c_api_utils.h @@ -1,5 +1,5 @@ /*! - * Copyright (c) 2021 by XGBoost Contributors + * Copyright (c) 2021-2022 by XGBoost Contributors */ #ifndef XGBOOST_C_API_C_API_UTILS_H_ #define XGBOOST_C_API_C_API_UTILS_H_ @@ -241,5 +241,25 @@ inline void GenerateFeatureMap(Learner const *learner, } void XGBBuildInfoDevice(Json* p_info); + +template +auto const &RequiredArg(Json const &in, std::string const &key, StringView func) { + auto const &obj = get(in); + auto it = obj.find(key); + if (it == obj.cend() || IsA(it->second)) { + LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`"; + } + return get const>(it->second); +} + +template +auto const &OptionalArg(Json const &in, std::string const &key, T const &dft) { + auto const &obj = get(in); + auto it = obj.find(key); + if (it != obj.cend()) { + return get const>(it->second); + } + return dft; +} } // namespace xgboost #endif // XGBOOST_C_API_C_API_UTILS_H_ diff --git a/src/common/config.h b/src/common/config.h index 8070c67286d8..c8b98eb77f99 100644 --- a/src/common/config.h +++ b/src/common/config.h @@ -111,7 +111,7 @@ class ConfigParser { const auto last_char = str.find_last_not_of(" \t\n\r"); if (first_char == std::string::npos) { // Every character in str is a whitespace - return std::string(); + return {}; } CHECK_NE(last_char, std::string::npos); const auto substr_len = last_char + 1 - first_char; diff --git a/src/common/io.cc b/src/common/io.cc index 01bc1e36b3c1..8405e6604037 100644 --- a/src/common/io.cc +++ b/src/common/io.cc @@ -1,5 +1,5 @@ /*! - * Copyright (c) by XGBoost Contributors 2019 + * Copyright (c) by XGBoost Contributors 2019-2022 */ #if defined(__unix__) #include @@ -52,7 +52,7 @@ size_t PeekableInStream::PeekRead(void* dptr, size_t size) { FixedSizeStream::FixedSizeStream(PeekableInStream* stream) : PeekableInStream(stream), pointer_{0} { size_t constexpr kInitialSize = 4096; - size_t size {kInitialSize}, total {0}; + size_t size{kInitialSize}, total{0}; buffer_.clear(); while (true) { buffer_.resize(size); @@ -142,5 +142,18 @@ std::string LoadSequentialFile(std::string uri, bool stream) { buffer.resize(total); return buffer; } + +std::string FileExtension(std::string fname, bool lower) { + if (lower) { + std::transform(fname.begin(), fname.end(), fname.begin(), + [](char c) { return std::tolower(c); }); + } + auto splited = Split(fname, '.'); + if (splited.size() > 1) { + return splited.back(); + } else { + return ""; + } +} } // namespace common } // namespace xgboost diff --git a/src/common/io.h b/src/common/io.h index 70ad111243a2..b377623ea138 100644 --- a/src/common/io.h +++ b/src/common/io.h @@ -1,5 +1,5 @@ /*! - * Copyright 2014 by Contributors + * Copyright by XGBoost Contributors 2014-2022 * \file io.h * \brief general stream interface for serialization, I/O * \author Tianqi Chen @@ -86,15 +86,31 @@ class FixedSizeStream : public PeekableInStream { */ std::string LoadSequentialFile(std::string uri, bool stream = false); -inline std::string FileExtension(std::string const& fname) { - auto splited = Split(fname, '.'); - if (splited.size() > 1) { - return splited.back(); +/** + * \brief Get file extension from file name. + * + * \param lower Return in lower case. + * + * \return File extension without the `.` + */ +std::string FileExtension(std::string fname, bool lower = true); + +/** + * \brief Read the whole buffer from dmlc stream. + */ +inline std::string ReadAll(dmlc::Stream* fi, PeekableInStream* fp) { + std::string buffer; + if (auto fixed_size = dynamic_cast(fi)) { + fixed_size->Seek(common::MemoryFixSizeBuffer::kSeekEnd); + size_t size = fixed_size->Tell(); + buffer.resize(size); + fixed_size->Seek(0); + CHECK_EQ(fixed_size->Read(&buffer[0], size), size); } else { - return ""; + FixedSizeStream{fp}.Take(&buffer); } + return buffer; } - } // namespace common } // namespace xgboost #endif // XGBOOST_COMMON_IO_H_ diff --git a/src/gbm/gblinear_model.cc b/src/gbm/gblinear_model.cc index 3a86cc07be07..5e6f5dda9a1f 100644 --- a/src/gbm/gblinear_model.cc +++ b/src/gbm/gblinear_model.cc @@ -1,6 +1,7 @@ /*! - * Copyright 2019-2021 by Contributors + * Copyright 2019-2022 by Contributors */ +#include #include #include #include "xgboost/json.h" @@ -13,22 +14,28 @@ void GBLinearModel::SaveModel(Json* p_out) const { auto& out = *p_out; size_t const n_weights = weight.size(); - std::vector j_weights(n_weights); - for (size_t i = 0; i < n_weights; ++i) { - j_weights[i] = weight[i]; - } + F32Array j_weights{n_weights}; + std::copy(weight.begin(), weight.end(), j_weights.GetArray().begin()); out["weights"] = std::move(j_weights); out["boosted_rounds"] = Json{this->num_boosted_rounds}; } void GBLinearModel::LoadModel(Json const& in) { - auto const& j_weights = get(in["weights"]); - auto n_weights = j_weights.size(); - weight.resize(n_weights); - for (size_t i = 0; i < n_weights; ++i) { - weight[i] = get(j_weights[i]); - } auto const& obj = get(in); + auto weight_it = obj.find("weights"); + if (IsA(weight_it->second)) { + auto const& j_weights = get(weight_it->second); + weight.resize(j_weights.size()); + std::copy(j_weights.begin(), j_weights.end(), weight.begin()); + } else { + auto const& j_weights = get(weight_it->second); + auto n_weights = j_weights.size(); + weight.resize(n_weights); + for (size_t i = 0; i < n_weights; ++i) { + weight[i] = get(j_weights[i]); + } + } + auto boosted_rounds = obj.find("boosted_rounds"); if (boosted_rounds != obj.cend()) { this->num_boosted_rounds = get(boosted_rounds->second); diff --git a/src/gbm/gbtree_model.cc b/src/gbm/gbtree_model.cc index e56dc0ad3a59..80659dab244d 100644 --- a/src/gbm/gbtree_model.cc +++ b/src/gbm/gbtree_model.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2019-2020 by Contributors + * Copyright 2019-2022 by Contributors */ #include @@ -69,13 +69,13 @@ void GBTreeModel::SaveModel(Json* p_out) const { out["gbtree_model_param"] = ToJson(param); std::vector trees_json(trees.size()); - for (size_t t = 0; t < trees.size(); ++t) { + common::ParallelFor(trees.size(), omp_get_max_threads(), [&](auto t) { auto const& tree = trees[t]; Json tree_json{Object()}; tree->SaveModel(&tree_json); - tree_json["id"] = Integer(static_cast(t)); + tree_json["id"] = Integer{static_cast(t)}; trees_json[t] = std::move(tree_json); - } + }); std::vector tree_info_json(tree_info.size()); for (size_t i = 0; i < tree_info.size(); ++i) { @@ -95,11 +95,11 @@ void GBTreeModel::LoadModel(Json const& in) { auto const& trees_json = get(in["trees"]); trees.resize(trees_json.size()); - for (size_t t = 0; t < trees_json.size(); ++t) { // NOLINT + common::ParallelFor(trees_json.size(), omp_get_max_threads(), [&](auto t) { auto tree_id = get(trees_json[t]["id"]); trees.at(tree_id).reset(new RegTree()); trees.at(tree_id)->LoadModel(trees_json[t]); - } + }); tree_info.resize(param.num_trees); auto const& tree_info_json = get(in["tree_info"]); diff --git a/src/learner.cc b/src/learner.cc index 9659784128c1..6551bf1ae744 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2014-2021 by Contributors + * Copyright 2014-2022 by Contributors * \file learner.cc * \brief Implementation of learning algorithm. * \author Tianqi Chen @@ -706,6 +706,21 @@ class LearnerConfiguration : public Learner { std::string const LearnerConfiguration::kEvalMetric {"eval_metric"}; // NOLINT +namespace { +StringView ModelMsg() { + return StringView{ + R"doc( + If you are loading a serialized model (like pickle in Python, RDS in R) generated by + older XGBoost, please export the model by calling `Booster.save_model` from that version + first, then load it back in current version. See: + + https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html + + for more details about differences between saving model and serializing. +)doc"}; +} +} // anonymous namespace + class LearnerIO : public LearnerConfiguration { private: std::set saved_configs_ = {"num_round"}; @@ -714,12 +729,17 @@ class LearnerIO : public LearnerConfiguration { std::string const serialisation_header_ { u8"CONFIG-offset:" }; public: - explicit LearnerIO(std::vector > cache) : - LearnerConfiguration{cache} {} + explicit LearnerIO(std::vector> cache) : LearnerConfiguration{cache} {} void LoadModel(Json const& in) override { CHECK(IsA(in)); - Version::Load(in); + auto version = Version::Load(in); + if (std::get<0>(version) == 1 && std::get<1>(version) < 6) { + LOG(WARNING) + << "Found JSON model saved before XGBoost 1.6, please save the model using current " + "version again. The support for old JSON model will be discontinued in XGBoost 2.3."; + } + auto const& learner = get(in["learner"]); mparam_.FromJson(learner.at("learner_model_param")); @@ -733,8 +753,8 @@ class LearnerIO : public LearnerConfiguration { auto const& gradient_booster = learner.at("gradient_booster"); name = get(gradient_booster["name"]); tparam_.UpdateAllowUnknown(Args{{"booster", name}}); - gbm_.reset(GradientBooster::Create(tparam_.booster, - &generic_parameters_, &learner_model_param_)); + gbm_.reset( + GradientBooster::Create(tparam_.booster, &generic_parameters_, &learner_model_param_)); gbm_->LoadModel(gradient_booster); auto const& j_attributes = get(learner.at("attributes")); @@ -746,20 +766,17 @@ class LearnerIO : public LearnerConfiguration { // feature names and types are saved in xgboost 1.4 auto it = learner.find("feature_names"); if (it != learner.cend()) { - auto const &feature_names = get(it->second); - feature_names_.clear(); - for (auto const &name : feature_names) { - feature_names_.emplace_back(get(name)); - } + auto const& feature_names = get(it->second); + feature_names_.resize(feature_names.size()); + std::transform(feature_names.cbegin(), feature_names.cend(), feature_names_.begin(), + [](Json const& fn) { return get(fn); }); } it = learner.find("feature_types"); if (it != learner.cend()) { - auto const &feature_types = get(it->second); - feature_types_.clear(); - for (auto const &name : feature_types) { - auto type = get(name); - feature_types_.emplace_back(type); - } + auto const& feature_types = get(it->second); + feature_types_.resize(feature_types.size()); + std::transform(feature_types.cbegin(), feature_types.cend(), feature_types_.begin(), + [](Json const& fn) { return get(fn); }); } this->need_configuration_ = true; @@ -799,6 +816,7 @@ class LearnerIO : public LearnerConfiguration { feature_types.emplace_back(type); } } + // About to be deprecated by JSON format void LoadModel(dmlc::Stream* fi) override { generic_parameters_.UpdateAllowUnknown(Args{}); @@ -817,15 +835,20 @@ class LearnerIO : public LearnerConfiguration { } } - if (header[0] == '{') { - // Dispatch to JSON - auto json_stream = common::FixedSizeStream(&fp); - std::string buffer; - json_stream.Take(&buffer); - auto model = Json::Load({buffer.c_str(), buffer.size()}); + if (header[0] == '{') { // Dispatch to JSON + auto buffer = common::ReadAll(fi, &fp); + Json model; + if (header[1] == '"') { + model = Json::Load(StringView{buffer}); + } else if (std::isalpha(header[1])) { + model = Json::Load(StringView{buffer}, std::ios::binary); + } else { + LOG(FATAL) << "Invalid model format"; + } this->LoadModel(model); return; } + // use the peekable reader. fi = &fp; // read parameter @@ -983,45 +1006,46 @@ class LearnerIO : public LearnerConfiguration { void Save(dmlc::Stream* fo) const override { Json memory_snapshot{Object()}; memory_snapshot["Model"] = Object(); - auto &model = memory_snapshot["Model"]; + auto& model = memory_snapshot["Model"]; this->SaveModel(&model); memory_snapshot["Config"] = Object(); - auto &config = memory_snapshot["Config"]; + auto& config = memory_snapshot["Config"]; this->SaveConfig(&config); - std::string out_str; - Json::Dump(memory_snapshot, &out_str); - fo->Write(out_str.c_str(), out_str.size()); + + std::vector stream; + Json::Dump(memory_snapshot, &stream, std::ios::binary); + fo->Write(stream.data(), stream.size()); } void Load(dmlc::Stream* fi) override { common::PeekableInStream fp(fi); - char c {0}; - fp.PeekRead(&c, 1); - if (c == '{') { - std::string buffer; - common::FixedSizeStream{&fp}.Take(&buffer); - auto memory_snapshot = Json::Load({buffer.c_str(), buffer.size()}); - this->LoadModel(memory_snapshot["Model"]); - this->LoadConfig(memory_snapshot["Config"]); + char header[2]; + fp.PeekRead(header, 2); + if (header[0] == '{') { + auto buffer = common::ReadAll(fi, &fp); + Json memory_snapshot; + if (header[1] == '"') { + memory_snapshot = Json::Load(StringView{buffer}); + LOG(WARNING) << ModelMsg(); + } else if (std::isalpha(header[1])) { + memory_snapshot = Json::Load(StringView{buffer}, std::ios::binary); + } else { + LOG(FATAL) << "Invalid serialization file."; + } + if (IsA(memory_snapshot["Model"])) { + // R has xgb.load that doesn't distinguish whether configuration is saved. + // We should migrate to use `xgb.load.raw` instead. + this->LoadModel(memory_snapshot); + } else { + this->LoadModel(memory_snapshot["Model"]); + this->LoadConfig(memory_snapshot["Config"]); + } } else { std::string header; header.resize(serialisation_header_.size()); CHECK_EQ(fp.Read(&header[0], header.size()), serialisation_header_.size()); // Avoid printing the content in loaded header, which might be random binary code. - CHECK(header == serialisation_header_) // NOLINT - << R"doc( - - If you are loading a serialized model (like pickle in Python) generated by older - XGBoost, please export the model by calling `Booster.save_model` from that version - first, then load it back in current version. There's a simple script for helping - the process. See: - - https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html - - for reference to the script, and more details about differences between saving model and - serializing. - -)doc"; + CHECK(header == serialisation_header_) << ModelMsg(); int64_t sz {-1}; CHECK_EQ(fp.Read(&sz, sizeof(sz)), sizeof(sz)); if (!DMLC_IO_NO_ENDIAN_SWAP) { diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 96efb4d68d11..07e0f1439574 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2015-2021 by Contributors + * Copyright 2015-2022 by Contributors * \file tree_model.cc * \brief model structure for tree */ @@ -893,27 +893,57 @@ void RegTree::Save(dmlc::Stream* fo) const { } } } +// typed array, not boolean +template +std::enable_if_t::value && !std::is_same::value, T> GetElem( + std::vector const& arr, size_t i) { + return arr[i]; +} +// typed array boolean +template +std::enable_if_t::value && std::is_same::value && + std::is_same::value, + bool> +GetElem(std::vector const& arr, size_t i) { + return arr[i] == 1; +} +// json array +template +std::enable_if_t< + std::is_same::value, + std::conditional_t::value, int64_t, + std::conditional_t::value, bool, float>>> +GetElem(std::vector const& arr, size_t i) { + if (std::is_same::value && !IsA(arr[i])) { + return get(arr[i]) == 1; + } + return get(arr[i]); +} +template void RegTree::LoadCategoricalSplit(Json const& in) { - auto const& categories_segments = get(in["categories_segments"]); - auto const& categories_sizes = get(in["categories_sizes"]); - auto const& categories_nodes = get(in["categories_nodes"]); - auto const& categories = get(in["categories"]); + using I64ArrayT = std::conditional_t; + using I32ArrayT = std::conditional_t; + + auto const& categories_segments = get(in["categories_segments"]); + auto const& categories_sizes = get(in["categories_sizes"]); + auto const& categories_nodes = get(in["categories_nodes"]); + auto const& categories = get(in["categories"]); size_t cnt = 0; bst_node_t last_cat_node = -1; if (!categories_nodes.empty()) { - last_cat_node = get(categories_nodes[cnt]); + last_cat_node = GetElem(categories_nodes, cnt); } for (bst_node_t nidx = 0; nidx < param.num_nodes; ++nidx) { if (nidx == last_cat_node) { - auto j_begin = get(categories_segments[cnt]); - auto j_end = get(categories_sizes[cnt]) + j_begin; + auto j_begin = GetElem(categories_segments, cnt); + auto j_end = GetElem(categories_sizes, cnt) + j_begin; bst_cat_t max_cat{std::numeric_limits::min()}; CHECK_NE(j_end - j_begin, 0) << nidx; for (auto j = j_begin; j < j_end; ++j) { - auto const &category = get(categories[j]); + auto const& category = GetElem(categories, j); auto cat = common::AsCat(category); max_cat = std::max(max_cat, cat); } @@ -924,7 +954,7 @@ void RegTree::LoadCategoricalSplit(Json const& in) { std::vector cat_bits_storage(size, 0); common::CatBitField cat_bits{common::Span(cat_bits_storage)}; for (auto j = j_begin; j < j_end; ++j) { - cat_bits.Set(common::AsCat(get(categories[j]))); + cat_bits.Set(common::AsCat(GetElem(categories, j))); } auto begin = split_categories_.size(); @@ -936,9 +966,9 @@ void RegTree::LoadCategoricalSplit(Json const& in) { ++cnt; if (cnt == categories_nodes.size()) { - last_cat_node = -1; + last_cat_node = -1; // Don't break, we still need to initialize the remaining nodes. } else { - last_cat_node = get(categories_nodes[cnt]); + last_cat_node = GetElem(categories_nodes, cnt); } } else { split_categories_segments_[nidx].beg = categories.size(); @@ -947,104 +977,144 @@ void RegTree::LoadCategoricalSplit(Json const& in) { } } +template void RegTree::LoadCategoricalSplit(Json const& in); +template void RegTree::LoadCategoricalSplit(Json const& in); + void RegTree::SaveCategoricalSplit(Json* p_out) const { auto& out = *p_out; CHECK_EQ(this->split_types_.size(), param.num_nodes); CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes); - std::vector categories_segments; - std::vector categories_sizes; - std::vector categories; - std::vector categories_nodes; + I64Array categories_segments; + I64Array categories_sizes; + I32Array categories; // bst_cat_t = int32_t + I32Array categories_nodes; // bst_note_t = int32_t for (size_t i = 0; i < nodes_.size(); ++i) { if (this->split_types_[i] == FeatureType::kCategorical) { - categories_nodes.emplace_back(i); - auto begin = categories.size(); - categories_segments.emplace_back(static_cast(begin)); + categories_nodes.GetArray().emplace_back(i); + auto begin = categories.Size(); + categories_segments.GetArray().emplace_back(begin); auto segment = split_categories_segments_[i]; - auto node_categories = - this->GetSplitCategories().subspan(segment.beg, segment.size); + auto node_categories = this->GetSplitCategories().subspan(segment.beg, segment.size); common::KCatBitField const cat_bits(node_categories); for (size_t i = 0; i < cat_bits.Size(); ++i) { if (cat_bits.Check(i)) { - categories.emplace_back(static_cast(i)); + categories.GetArray().emplace_back(i); } } - size_t size = categories.size() - begin; - categories_sizes.emplace_back(static_cast(size)); + size_t size = categories.Size() - begin; + categories_sizes.GetArray().emplace_back(size); CHECK_NE(size, 0); } } - out["categories_segments"] = categories_segments; - out["categories_sizes"] = categories_sizes; - out["categories_nodes"] = categories_nodes; - out["categories"] = categories; + out["categories_segments"] = std::move(categories_segments); + out["categories_sizes"] = std::move(categories_sizes); + out["categories_nodes"] = std::move(categories_nodes); + out["categories"] = std::move(categories); } -void RegTree::LoadModel(Json const& in) { - FromJson(in["tree_param"], ¶m); - auto n_nodes = param.num_nodes; +template , + typename U8ArrayT = std::conditional_t, + typename I32ArrayT = std::conditional_t, + typename I64ArrayT = std::conditional_t, + typename IndexArrayT = std::conditional_t> +bool LoadModelImpl(Json const& in, TreeParam* param, std::vector* p_stats, + std::vector* p_split_types, std::vector* p_nodes, + std::vector* p_split_categories_segments) { + auto& stats = *p_stats; + auto& split_types = *p_split_types; + auto& nodes = *p_nodes; + auto& split_categories_segments = *p_split_categories_segments; + + FromJson(in["tree_param"], param); + auto n_nodes = param->num_nodes; CHECK_NE(n_nodes, 0); // stats - auto const& loss_changes = get(in["loss_changes"]); + auto const& loss_changes = get(in["loss_changes"]); CHECK_EQ(loss_changes.size(), n_nodes); - auto const& sum_hessian = get(in["sum_hessian"]); + auto const& sum_hessian = get(in["sum_hessian"]); CHECK_EQ(sum_hessian.size(), n_nodes); - auto const& base_weights = get(in["base_weights"]); + auto const& base_weights = get(in["base_weights"]); CHECK_EQ(base_weights.size(), n_nodes); // nodes - auto const& lefts = get(in["left_children"]); + auto const& lefts = get(in["left_children"]); CHECK_EQ(lefts.size(), n_nodes); - auto const& rights = get(in["right_children"]); + auto const& rights = get(in["right_children"]); CHECK_EQ(rights.size(), n_nodes); - auto const& parents = get(in["parents"]); + auto const& parents = get(in["parents"]); CHECK_EQ(parents.size(), n_nodes); - auto const& indices = get(in["split_indices"]); + auto const& indices = get(in["split_indices"]); CHECK_EQ(indices.size(), n_nodes); - auto const& conds = get(in["split_conditions"]); + auto const& conds = get(in["split_conditions"]); CHECK_EQ(conds.size(), n_nodes); - auto const& default_left = get(in["default_left"]); + auto const& default_left = get(in["default_left"]); CHECK_EQ(default_left.size(), n_nodes); bool has_cat = get(in).find("split_type") != get(in).cend(); - std::vector split_type; + std::remove_const_t(in["split_type"]))>> + split_type; if (has_cat) { - split_type = get(in["split_type"]); + split_type = get(in["split_type"]); } - stats_.clear(); - nodes_.clear(); + stats = std::remove_reference_t(n_nodes); + nodes = std::remove_reference_t(n_nodes); + split_types = std::remove_reference_t(n_nodes); + split_categories_segments = std::remove_reference_t(n_nodes); - stats_.resize(n_nodes); - nodes_.resize(n_nodes); - split_types_.resize(n_nodes); - split_categories_segments_.resize(n_nodes); + static_assert(std::is_integral(lefts, 0))>::value, ""); + static_assert(std::is_floating_point(loss_changes, 0))>::value, ""); + CHECK_EQ(n_nodes, split_categories_segments.size()); - CHECK_EQ(n_nodes, split_categories_segments_.size()); for (int32_t i = 0; i < n_nodes; ++i) { - auto& s = stats_[i]; - s.loss_chg = get(loss_changes[i]); - s.sum_hess = get(sum_hessian[i]); - s.base_weight = get(base_weights[i]); - - auto& n = nodes_[i]; - bst_node_t left = get(lefts[i]); - bst_node_t right = get(rights[i]); - bst_node_t parent = get(parents[i]); - bst_feature_t ind = get(indices[i]); - float cond { get(conds[i]) }; - bool dft_left { get(default_left[i]) }; - n = Node{left, right, parent, ind, cond, dft_left}; + auto& s = stats[i]; + s.loss_chg = GetElem(loss_changes, i); + s.sum_hess = GetElem(sum_hessian, i); + s.base_weight = GetElem(base_weights, i); + + auto& n = nodes[i]; + bst_node_t left = GetElem(lefts, i); + bst_node_t right = GetElem(rights, i); + bst_node_t parent = GetElem(parents, i); + bst_feature_t ind = GetElem(indices, i); + float cond{GetElem(conds, i)}; + bool dft_left{GetElem(default_left, i)}; + n = RegTree::Node{left, right, parent, ind, cond, dft_left}; if (has_cat) { - split_types_[i] = - static_cast(get(split_type[i])); + split_types[i] = static_cast(GetElem(split_type, i)); } } + return has_cat; +} + +void RegTree::LoadModel(Json const& in) { + bool has_cat{false}; + bool typed = IsA(in["loss_changes"]); + bool feature_is_64 = IsA(in["split_indices"]); + if (typed && feature_is_64) { + has_cat = LoadModelImpl(in, ¶m, &stats_, &split_types_, &nodes_, + &split_categories_segments_); + } else if (typed && !feature_is_64) { + has_cat = LoadModelImpl(in, ¶m, &stats_, &split_types_, &nodes_, + &split_categories_segments_); + } else if (!typed && feature_is_64) { + has_cat = LoadModelImpl(in, ¶m, &stats_, &split_types_, &nodes_, + &split_categories_segments_); + } else { + has_cat = LoadModelImpl(in, ¶m, &stats_, &split_types_, &nodes_, + &split_categories_segments_); + } + if (has_cat) { - this->LoadCategoricalSplit(in); + if (typed) { + this->LoadCategoricalSplit(in); + } else { + this->LoadCategoricalSplit(in); + } } else { this->split_categories_segments_.resize(this->param.num_nodes); std::fill(split_types_.begin(), split_types_.end(), FeatureType::kNumerical); @@ -1058,7 +1128,7 @@ void RegTree::LoadModel(Json const& in) { } // easier access to [] operator auto& self = *this; - for (auto nid = 1; nid < n_nodes; ++nid) { + for (auto nid = 1; nid < param.num_nodes; ++nid) { auto parent = self[nid].Parent(); CHECK_NE(parent, RegTree::kInvalidNodeId); self[nid].SetParent(self[nid].Parent(), self[parent].LeftChild() == nid); @@ -1079,39 +1149,51 @@ void RegTree::SaveModel(Json* p_out) const { CHECK_EQ(param.num_nodes, static_cast(stats_.size())); out["tree_param"] = ToJson(param); CHECK_EQ(get(out["tree_param"]["num_nodes"]), std::to_string(param.num_nodes)); - using I = Integer::Int; auto n_nodes = param.num_nodes; // stats - std::vector loss_changes(n_nodes); - std::vector sum_hessian(n_nodes); - std::vector base_weights(n_nodes); + F32Array loss_changes(n_nodes); + F32Array sum_hessian(n_nodes); + F32Array base_weights(n_nodes); // nodes - std::vector lefts(n_nodes); - std::vector rights(n_nodes); - std::vector parents(n_nodes); - std::vector indices(n_nodes); - std::vector conds(n_nodes); - std::vector default_left(n_nodes); - std::vector split_type(n_nodes); - CHECK_EQ(this->split_types_.size(), param.num_nodes); + I32Array lefts(n_nodes); + I32Array rights(n_nodes); + I32Array parents(n_nodes); - for (bst_node_t i = 0; i < n_nodes; ++i) { - auto const& s = stats_[i]; - loss_changes[i] = s.loss_chg; - sum_hessian[i] = s.sum_hess; - base_weights[i] = s.base_weight; - auto const& n = nodes_[i]; - lefts[i] = static_cast(n.LeftChild()); - rights[i] = static_cast(n.RightChild()); - parents[i] = static_cast(n.Parent()); - indices[i] = static_cast(n.SplitIndex()); - conds[i] = n.SplitCond(); - default_left[i] = n.DefaultLeft(); + F32Array conds(n_nodes); + U8Array default_left(n_nodes); + U8Array split_type(n_nodes); + CHECK_EQ(this->split_types_.size(), param.num_nodes); - split_type[i] = static_cast(this->NodeSplitType(i)); + auto save_tree = [&](auto* p_indices_array) { + auto& indices_array = *p_indices_array; + for (bst_node_t i = 0; i < n_nodes; ++i) { + auto const& s = stats_[i]; + loss_changes.Set(i, s.loss_chg); + sum_hessian.Set(i, s.sum_hess); + base_weights.Set(i, s.base_weight); + + auto const& n = nodes_[i]; + lefts.Set(i, n.LeftChild()); + rights.Set(i, n.RightChild()); + parents.Set(i, n.Parent()); + indices_array.Set(i, n.SplitIndex()); + conds.Set(i, n.SplitCond()); + default_left.Set(i, static_cast(!!n.DefaultLeft())); + + split_type.Set(i, static_cast(this->NodeSplitType(i))); + } + }; + if (this->param.num_feature > static_cast(std::numeric_limits::max())) { + I64Array indices_64(n_nodes); + save_tree(&indices_64); + out["split_indices"] = std::move(indices_64); + } else { + I32Array indices_32(n_nodes); + save_tree(&indices_32); + out["split_indices"] = std::move(indices_32); } this->SaveCategoricalSplit(&out); @@ -1124,7 +1206,7 @@ void RegTree::SaveModel(Json* p_out) const { out["left_children"] = std::move(lefts); out["right_children"] = std::move(rights); out["parents"] = std::move(parents); - out["split_indices"] = std::move(indices); + out["split_conditions"] = std::move(conds); out["default_left"] = std::move(default_left); } diff --git a/tests/ci_build/conda_env/cpu_test.yml b/tests/ci_build/conda_env/cpu_test.yml index 189575a229ef..5a0a64245bab 100644 --- a/tests/ci_build/conda_env/cpu_test.yml +++ b/tests/ci_build/conda_env/cpu_test.yml @@ -32,6 +32,7 @@ dependencies: - awscli - numba - llvmlite +- py-ubjson - pip: - shap - ipython # required by shap at import time. diff --git a/tests/ci_build/conda_env/macos_cpu_test.yml b/tests/ci_build/conda_env/macos_cpu_test.yml index 016fea55501f..7fed0704a4af 100644 --- a/tests/ci_build/conda_env/macos_cpu_test.yml +++ b/tests/ci_build/conda_env/macos_cpu_test.yml @@ -31,6 +31,7 @@ dependencies: - jsonschema - boto3 - awscli +- py-ubjson - pip: - sphinx_rtd_theme - datatable diff --git a/tests/ci_build/conda_env/win64_cpu_test.yml b/tests/ci_build/conda_env/win64_cpu_test.yml index 25378d52f552..a8ac47de7ced 100644 --- a/tests/ci_build/conda_env/win64_cpu_test.yml +++ b/tests/ci_build/conda_env/win64_cpu_test.yml @@ -18,3 +18,4 @@ dependencies: - jsonschema - python-graphviz - pip +- py-ubjson diff --git a/tests/ci_build/conda_env/win64_test.yml b/tests/ci_build/conda_env/win64_test.yml index f8eb30f6a354..bf4274ef3d5b 100644 --- a/tests/ci_build/conda_env/win64_test.yml +++ b/tests/ci_build/conda_env/win64_test.yml @@ -16,3 +16,4 @@ dependencies: - python-graphviz - modin-ray - pip +- py-ubjson diff --git a/tests/cpp/c_api/test_c_api.cc b/tests/cpp/c_api/test_c_api.cc index 319aafba02a2..17110d3f222e 100644 --- a/tests/cpp/c_api/test_c_api.cc +++ b/tests/cpp/c_api/test_c_api.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2019-2020 XGBoost contributors + * Copyright 2019-2022 XGBoost contributors */ #include #include @@ -150,6 +150,33 @@ TEST(CAPI, JsonModelIO) { ASSERT_EQ(model_str_0.front(), '{'); ASSERT_EQ(model_str_0, model_str_1); + + /** + * In memory + */ + bst_ulong len{0}; + char const *data; + XGBoosterSaveModelToBuffer(handle, R"({"format": "ubj"})", &len, &data); + ASSERT_GT(len, 3); + + XGBoosterLoadModelFromBuffer(handle, data, len); + char const *saved; + bst_ulong saved_len{0}; + XGBoosterSaveModelToBuffer(handle, R"({"format": "ubj"})", &saved_len, &saved); + ASSERT_EQ(len, saved_len); + auto l = StringView{data, len}; + auto r = StringView{saved, saved_len}; + ASSERT_EQ(l.size(), r.size()); + ASSERT_EQ(l, r); + + std::string buffer; + Json::Dump(Json::Load(l, std::ios::binary), &buffer); + ASSERT_EQ(model_str_0.size() - 1, buffer.size()); + ASSERT_EQ(model_str_0.back(), '\0'); + ASSERT_TRUE(std::equal(model_str_0.begin(), model_str_0.end() - 1, buffer.begin())); + + ASSERT_EQ(XGBoosterSaveModelToBuffer(handle, R"({})", &len, &data), -1); + ASSERT_EQ(XGBoosterSaveModelToBuffer(handle, R"({"format": "foo"})", &len, &data), -1); } TEST(CAPI, CatchDMLCError) { diff --git a/tests/cpp/test_serialization.cc b/tests/cpp/test_serialization.cc index f971d70a1d65..bf459cf35dcb 100644 --- a/tests/cpp/test_serialization.cc +++ b/tests/cpp/test_serialization.cc @@ -178,8 +178,8 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr learner->Save(&fo); } - Json m_0 = Json::Load(StringView{continued_model.c_str(), continued_model.size()}); - Json m_1 = Json::Load(StringView{model_at_2kiter.c_str(), model_at_2kiter.size()}); + Json m_0 = Json::Load(StringView{continued_model}, std::ios::binary); + Json m_1 = Json::Load(StringView{model_at_2kiter}, std::ios::binary); CompareJSON(m_0, m_1); } @@ -214,8 +214,8 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr common::MemoryBufferStream fo(&serialised_model_tmp); learner->Save(&fo); - Json m_0 = Json::Load(StringView{model_at_2kiter.c_str(), model_at_2kiter.size()}); - Json m_1 = Json::Load(StringView{serialised_model_tmp.c_str(), serialised_model_tmp.size()}); + Json m_0 = Json::Load(StringView{model_at_2kiter}, std::ios::binary); + Json m_1 = Json::Load(StringView{serialised_model_tmp}, std::ios::binary); // GPU ID is changed as data is coming from device. ASSERT_EQ(get(m_0["Config"]["learner"]["generic_param"]).erase("gpu_id"), get(m_1["Config"]["learner"]["generic_param"]).erase("gpu_id")); diff --git a/tests/cpp/tree/test_tree_model.cc b/tests/cpp/tree/test_tree_model.cc index a1e64f542b20..9562b715dd8e 100644 --- a/tests/cpp/tree/test_tree_model.cc +++ b/tests/cpp/tree/test_tree_model.cc @@ -198,8 +198,7 @@ void CheckReload(RegTree const &tree) { Json saved{Object()}; loaded_tree.SaveModel(&saved); - auto same = out == saved; - ASSERT_TRUE(same); + ASSERT_EQ(out, saved); } TEST(Tree, CategoricalIO) { @@ -433,12 +432,12 @@ TEST(Tree, JsonIO) { ASSERT_EQ(get(tparam["num_nodes"]), "3"); ASSERT_EQ(get(tparam["size_leaf_vector"]), "0"); - ASSERT_EQ(get(j_tree["left_children"]).size(), 3ul); - ASSERT_EQ(get(j_tree["right_children"]).size(), 3ul); - ASSERT_EQ(get(j_tree["parents"]).size(), 3ul); - ASSERT_EQ(get(j_tree["split_indices"]).size(), 3ul); - ASSERT_EQ(get(j_tree["split_conditions"]).size(), 3ul); - ASSERT_EQ(get(j_tree["default_left"]).size(), 3ul); + ASSERT_EQ(get(j_tree["left_children"]).size(), 3ul); + ASSERT_EQ(get(j_tree["right_children"]).size(), 3ul); + ASSERT_EQ(get(j_tree["parents"]).size(), 3ul); + ASSERT_EQ(get(j_tree["split_indices"]).size(), 3ul); + ASSERT_EQ(get(j_tree["split_conditions"]).size(), 3ul); + ASSERT_EQ(get(j_tree["default_left"]).size(), 3ul); RegTree loaded_tree; loaded_tree.LoadModel(j_tree); diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index 219a6899a362..d80d9272246d 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -14,7 +14,7 @@ rng = np.random.RandomState(1994) -def json_model(model_path, parameters): +def json_model(model_path: str, parameters: dict) -> dict: X = np.random.random((10, 3)) y = np.random.randint(2, size=(10,)) @@ -22,9 +22,14 @@ def json_model(model_path, parameters): bst = xgb.train(parameters, dm1) bst.save_model(model_path) + if model_path.endswith("ubj"): + import ubjson + with open(model_path, "rb") as ubjfd: + model = ubjson.load(ubjfd) + else: + with open(model_path, 'r') as fd: + model = json.load(fd) - with open(model_path, 'r') as fd: - model = json.load(fd) return model @@ -259,23 +264,40 @@ def test_model_binary_io(self): buf_from_raw = from_raw.save_raw() assert buf == buf_from_raw - def test_model_json_io(self): + def run_model_json_io(self, parameters: dict, ext: str) -> None: + if ext == "ubj" and tm.no_ubjson()["condition"]: + pytest.skip(tm.no_ubjson()["reason"]) + loc = locale.getpreferredencoding(False) - model_path = 'test_model_json_io.json' - parameters = {'tree_method': 'hist', 'booster': 'gbtree'} + model_path = 'test_model_json_io.' + ext j_model = json_model(model_path, parameters) assert isinstance(j_model['learner'], dict) bst = xgb.Booster(model_file=model_path) bst.save_model(fname=model_path) - with open(model_path, 'r') as fd: - j_model = json.load(fd) + if ext == "ubj": + import ubjson + with open(model_path, "rb") as ubjfd: + j_model = ubjson.load(ubjfd) + else: + with open(model_path, 'r') as fd: + j_model = json.load(fd) + assert isinstance(j_model['learner'], dict) os.remove(model_path) assert locale.getpreferredencoding(False) == loc + @pytest.mark.parametrize("ext", ["json", "ubj"]) + def test_model_json_io(self, ext: str) -> None: + parameters = {"booster": "gbtree", "tree_method": "hist"} + self.run_model_json_io(parameters, ext) + parameters = {"booster": "gblinear"} + self.run_model_json_io(parameters, ext) + parameters = {"booster": "dart", "tree_method": "hist"} + self.run_model_json_io(parameters, ext) + @pytest.mark.skipif(**tm.no_json_schema()) def test_json_io_schema(self): import jsonschema diff --git a/tests/python/test_pickling.py b/tests/python/test_pickling.py index 6dc02063221b..37bbc6c13a04 100644 --- a/tests/python/test_pickling.py +++ b/tests/python/test_pickling.py @@ -2,6 +2,7 @@ import numpy as np import xgboost as xgb import os +import json kRows = 100 @@ -15,13 +16,14 @@ def generate_data(): class TestPickling: - def run_model_pickling(self, xgb_params): + def run_model_pickling(self, xgb_params) -> str: X, y = generate_data() dtrain = xgb.DMatrix(X, y) bst = xgb.train(xgb_params, dtrain) dump_0 = bst.get_dump(dump_format='json') assert dump_0 + config_0 = bst.save_config() filename = 'model.pkl' @@ -42,9 +44,22 @@ def run_model_pickling(self, xgb_params): if os.path.exists(filename): os.remove(filename) + config_1 = bst.save_config() + assert config_0 == config_1 + return json.loads(config_0) + def test_model_pickling_json(self): - params = { - 'nthread': 1, - 'tree_method': 'hist', - } - self.run_model_pickling(params) + def check(config): + updater = config["learner"]["gradient_booster"]["updater"] + if params["tree_method"] == "exact": + subsample = updater["grow_colmaker"]["train_param"]["subsample"] + else: + subsample = updater["grow_quantile_histmaker"]["train_param"]["subsample"] + assert float(subsample) == 0.5 + + params = {"nthread": 8, "tree_method": "hist", "subsample": 0.5} + config = self.run_model_pickling(params) + check(config) + params = {"nthread": 8, "tree_method": "exact", "subsample": 0.5} + config = self.run_model_pickling(params) + check(config) diff --git a/tests/python/testing.py b/tests/python/testing.py index cff0e96d56cc..5f8aef124cb0 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -29,6 +29,15 @@ memory = Memory('./cachedir', verbose=0) +def no_ubjson(): + reason = "ubjson is not intsalled." + try: + import ubjson # noqa + return {"condition": False, "reason": reason} + except ImportError: + return {"condition": True, "reason": reason} + + def no_sklearn(): return {'condition': not SKLEARN_INSTALLED, 'reason': 'Scikit-Learn is not installed'}