Skip to content

Commit

Permalink
Cleaner code.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 10, 2022
1 parent f35e1ac commit 4ad62f3
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -729,8 +729,7 @@ class LearnerIO : public LearnerConfiguration {
std::string const serialisation_header_ { u8"CONFIG-offset:" };

public:
explicit LearnerIO(std::vector<std::shared_ptr<DMatrix> > cache) :
LearnerConfiguration{cache} {}
explicit LearnerIO(std::vector<std::shared_ptr<DMatrix>> cache) : LearnerConfiguration{cache} {}

void LoadModel(Json const& in) override {
CHECK(IsA<Object>(in));
Expand All @@ -754,8 +753,8 @@ class LearnerIO : public LearnerConfiguration {
auto const& gradient_booster = learner.at("gradient_booster");
name = get<String>(gradient_booster["name"]);
tparam_.UpdateAllowUnknown(Args{{"booster", name}});
gbm_.reset(GradientBooster::Create(tparam_.booster,
&generic_parameters_, &learner_model_param_));
gbm_.reset(
GradientBooster::Create(tparam_.booster, &generic_parameters_, &learner_model_param_));
gbm_->LoadModel(gradient_booster);

auto const& j_attributes = get<Object const>(learner.at("attributes"));
Expand All @@ -767,20 +766,17 @@ class LearnerIO : public LearnerConfiguration {
// feature names and types are saved in xgboost 1.4
auto it = learner.find("feature_names");
if (it != learner.cend()) {
auto const &feature_names = get<Array const>(it->second);
feature_names_.clear();
for (auto const &name : feature_names) {
feature_names_.emplace_back(get<String const>(name));
}
auto const& feature_names = get<Array const>(it->second);
feature_names_.resize(feature_names.size());
std::transform(feature_names.cbegin(), feature_names.cend(), feature_names_.begin(),
[](Json const& fn) { return get<String const>(fn); });
}
it = learner.find("feature_types");
if (it != learner.cend()) {
auto const &feature_types = get<Array const>(it->second);
feature_types_.clear();
for (auto const &name : feature_types) {
auto type = get<String const>(name);
feature_types_.emplace_back(type);
}
auto const& feature_types = get<Array const>(it->second);
feature_types_.resize(feature_types.size());
std::transform(feature_types.cbegin(), feature_types.cend(), feature_types_.begin(),
[](Json const& fn) { return get<String const>(fn); });
}

this->need_configuration_ = true;
Expand Down

0 comments on commit 4ad62f3

Please sign in to comment.