Skip to content

Commit

Permalink
Fix changing locale. (#5314)
Browse files Browse the repository at this point in the history
* Fix changing locale.

* Don't use locale guard.

As number parsing is implemented in house, we don't need locale.

* Update doc.
  • Loading branch information
trivialfis authored Feb 17, 2020
1 parent 7e32af5 commit e433a37
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 25 deletions.
4 changes: 3 additions & 1 deletion doc/tutorials/saving_model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ You can load it back to the model generated by same version of XGBoost by:
bst.load_config(config)
This way users can study the internal representation more closely.
This way users can study the internal representation more closely. Please note that some
JSON generators make use of locale dependent floating point serialization methods, which
is not supported by XGBoost.

************
Future Plans
Expand Down
25 changes: 1 addition & 24 deletions src/common/json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
* Copyright (c) by Contributors 2019
*/
#include <cctype>
#include <locale>
#include <sstream>
#include <limits>
#include <cmath>
Expand Down Expand Up @@ -692,47 +693,23 @@ Json JsonReader::ParseBoolean() {
return Json{JsonBoolean{result}};
}

// This is an ad-hoc solution for writing numeric value in standard way. We need to add
// a locale independent way of writing stream like `std::{from, to}_chars' from C++-17.
// FIXME(trivialfis): Remove this.
class GlobalCLocale {
std::locale ori_;

public:
GlobalCLocale() : ori_{std::locale()} {
std::string const name {"C"};
try {
std::locale::global(std::locale(name.c_str()));
} catch (std::runtime_error const& e) {
LOG(FATAL) << "Failed to set locale: " << name;
}
}
~GlobalCLocale() {
std::locale::global(ori_);
}
};

Json Json::Load(StringView str) {
GlobalCLocale guard;
JsonReader reader(str);
Json json{reader.Load()};
return json;
}

Json Json::Load(JsonReader* reader) {
GlobalCLocale guard;
Json json{reader->Load()};
return json;
}

void Json::Dump(Json json, std::ostream *stream, bool pretty) {
GlobalCLocale guard;
JsonWriter writer(stream, pretty);
writer.Save(json);
}

void Json::Dump(Json json, std::string* str, bool pretty) {
GlobalCLocale guard;
std::stringstream ss;
JsonWriter writer(&ss, pretty);
writer.Save(json);
Expand Down
3 changes: 3 additions & 0 deletions tests/python/test_basic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import testing as tm
import pytest
import locale

dpath = 'demo/data/'
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
Expand Down Expand Up @@ -300,6 +301,7 @@ def test_model_binary_io(self):
'reg_loss_param']['scale_pos_weight']) == 0.5

def test_model_json_io(self):
loc = locale.getpreferredencoding(False)
model_path = 'test_model_json_io.json'
parameters = {'tree_method': 'hist', 'booster': 'gbtree'}
j_model = json_model(model_path, parameters)
Expand All @@ -313,6 +315,7 @@ def test_model_json_io(self):
assert isinstance(j_model['learner'], dict)

os.remove(model_path)
assert locale.getpreferredencoding(False) == loc

@pytest.mark.skipif(**tm.no_json_schema())
def test_json_schema(self):
Expand Down

0 comments on commit e433a37

Please sign in to comment.