Skip to content

Commit

Permalink
Fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 11, 2022
1 parent 16fb4c5 commit b90b635
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
3 changes: 2 additions & 1 deletion jvm-packages/xgboost4j/src/native/xgboost4j.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -670,13 +670,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModel
BoosterHandle handle = (BoosterHandle) jhandle;
const char *format = jenv->GetStringUTFChars(jformat, 0);
bst_ulong len = 0;
const char* result;
const char *result{nullptr};
xgboost::Json config {xgboost::Object{}};
config["format"] = std::string{format};
std::string config_str;
xgboost::Json::Dump(config, &config_str);

int ret = XGBoosterSaveModelToBuffer(handle, config_str.c_str(), &len, &result);
JVM_CHECK_CALL(ret);
if (result) {
jbyteArray jarray = jenv->NewByteArray(len);
jenv->SetByteArrayRegion(jarray, 0, len, (jbyte *)result);
Expand Down
11 changes: 5 additions & 6 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
auto config = Json::Load(StringView{c_json_config});
auto missing = GetMissing(config);
std::string cache = RequiredArg<String>(config, "cache_prefix", __func__);
auto n_threads = OptionalArg<Integer>(config, "nthread", common::OmpGetNumThreads(0));
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
*out = new std::shared_ptr<xgboost::DMatrix>{
xgboost::DMatrix::Create(iter, proxy, reset, next, missing, n_threads, cache)};
API_END();
Expand Down Expand Up @@ -352,7 +352,7 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr,
StringView{data}, ncol);
auto config = Json::Load(StringView{c_json_config});
float missing = GetMissing(config);
auto n_threads = OptionalArg<Integer>(config, "nthread", common::OmpGetNumThreads(0));
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
API_END();
}
Expand All @@ -365,7 +365,7 @@ XGB_DLL int XGDMatrixCreateFromDense(char const *data,
xgboost::data::ArrayAdapter(StringView{data})};
auto config = Json::Load(StringView{c_json_config});
float missing = GetMissing(config);
auto n_threads = OptionalArg<Integer>(config, "nthread", common::OmpGetNumThreads(0));
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
*out =
new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
API_END();
Expand Down Expand Up @@ -919,9 +919,8 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
namespace {
void WarnOldModel() {
if (XGBOOST_VER_MAJOR >= 2) {
LOG(WARNING)
<< "Found deprecated binary model format, please add `.json` or `.ubj` as file extension. "
"Model format will default to JSON in XGBoost 2.2 when file extension is not specified.";
LOG(WARNING) << "Saving into deprecated binary model format, please consider using `json` or "
"`ubj`. Model format will default to JSON in XGBoost 2.2 if not specified.";
}
}
} // anonymous namespace
Expand Down
6 changes: 3 additions & 3 deletions src/c_api/c_api_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ inline void GenerateFeatureMap(Learner const *learner,
void XGBBuildInfoDevice(Json* p_info);

template <typename JT>
inline auto RequiredArg(Json in, std::string const &key, StringView func) {
auto const &RequiredArg(Json const &in, std::string const &key, StringView func) {
auto const &obj = get<Object const>(in);
auto it = obj.find(key);
if (it == obj.cend() || IsA<Null>(it->second)) {
Expand All @@ -253,11 +253,11 @@ inline auto RequiredArg(Json in, std::string const &key, StringView func) {
}

template <typename JT, typename T>
inline auto OptionalArg(Json in, std::string const &key, T const &dft) {
auto const &OptionalArg(Json const &in, std::string const &key, T const &dft) {
auto const &obj = get<Object const>(in);
auto it = obj.find(key);
if (it != obj.cend()) {
return static_cast<T>(get<std::remove_const_t<JT> const>(it->second));
return get<std::remove_const_t<JT> const>(it->second);
}
return dft;
}
Expand Down

0 comments on commit b90b635

Please sign in to comment.