Skip to content

Commit

Permalink
Enforce tree order in JSON. (dmlc#5974)
Browse files: browse the repository at this point in the history
* Make JSON model IO more future-proof by using the tree id in model loading.
  • Loading branch information
trivialfis committed Aug 11, 2020
1 parent b13fcfe commit fc06538
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 12 deletions.
24 changes: 13 additions & 11 deletions src/gbm/gbtree_model.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
/*!
* Copyright 2019 by Contributors
* Copyright 2019-2020 by Contributors
*/
#include <utility>

#include "xgboost/json.h"
#include "xgboost/logging.h"
#include "gbtree_model.h"
Expand Down Expand Up @@ -41,15 +43,14 @@ void GBTreeModel::SaveModel(Json* p_out) const {
auto& out = *p_out;
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
out["gbtree_model_param"] = ToJson(param);
std::vector<Json> trees_json;
size_t t = 0;
for (auto const& tree : trees) {
std::vector<Json> trees_json(trees.size());

for (size_t t = 0; t < trees.size(); ++t) {
auto const& tree = trees[t];
Json tree_json{Object()};
tree->SaveModel(&tree_json);
// The field is not used in XGBoost, but might be useful for external project.
tree_json["id"] = Integer(t);
trees_json.emplace_back(tree_json);
t++;
tree_json["id"] = Integer(static_cast<Integer::Int>(t));
trees_json[t] = std::move(tree_json);
}

std::vector<Json> tree_info_json(tree_info.size());
Expand All @@ -70,9 +71,10 @@ void GBTreeModel::LoadModel(Json const& in) {
auto const& trees_json = get<Array const>(in["trees"]);
trees.resize(trees_json.size());

for (size_t t = 0; t < trees.size(); ++t) {
trees[t].reset( new RegTree() );
trees[t]->LoadModel(trees_json[t]);
for (size_t t = 0; t < trees_json.size(); ++t) { // NOLINT
auto tree_id = get<Integer>(trees_json[t]["id"]);
trees.at(tree_id).reset(new RegTree());
trees.at(tree_id)->LoadModel(trees_json[t]);
}

tree_info.resize(param.num_trees);
Expand Down
11 changes: 10 additions & 1 deletion tests/cpp/test_learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,16 @@ TEST(Learner, JsonModelIO) {
Json out { Object() };
learner->SaveModel(&out);

learner->LoadModel(out);
dmlc::TemporaryDirectory tmpdir;

std::ofstream fout (tmpdir.path + "/model.json");
fout << out;
fout.close();

auto loaded_str = common::LoadSequentialFile(tmpdir.path + "/model.json");
Json loaded = Json::Load(StringView{loaded_str.c_str(), loaded_str.size()});

learner->LoadModel(loaded);
learner->Configure();

Json new_in { Object() };
Expand Down

0 comments on commit fc06538

Please sign in to comment.