From 4ad62f30b24a3556a200eca0089889dfb6104b01 Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 11 Jan 2022 00:16:32 +0800 Subject: [PATCH] Cleaner code. --- src/learner.cc | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/learner.cc b/src/learner.cc index d5e891c738e9..24932b9279f2 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -729,8 +729,7 @@ class LearnerIO : public LearnerConfiguration { std::string const serialisation_header_ { u8"CONFIG-offset:" }; public: - explicit LearnerIO(std::vector > cache) : - LearnerConfiguration{cache} {} + explicit LearnerIO(std::vector> cache) : LearnerConfiguration{cache} {} void LoadModel(Json const& in) override { CHECK(IsA(in)); @@ -754,8 +753,8 @@ class LearnerIO : public LearnerConfiguration { auto const& gradient_booster = learner.at("gradient_booster"); name = get(gradient_booster["name"]); tparam_.UpdateAllowUnknown(Args{{"booster", name}}); - gbm_.reset(GradientBooster::Create(tparam_.booster, - &generic_parameters_, &learner_model_param_)); + gbm_.reset( + GradientBooster::Create(tparam_.booster, &generic_parameters_, &learner_model_param_)); gbm_->LoadModel(gradient_booster); auto const& j_attributes = get(learner.at("attributes")); @@ -767,20 +766,17 @@ class LearnerIO : public LearnerConfiguration { // feature names and types are saved in xgboost 1.4 auto it = learner.find("feature_names"); if (it != learner.cend()) { - auto const &feature_names = get(it->second); - feature_names_.clear(); - for (auto const &name : feature_names) { - feature_names_.emplace_back(get(name)); - } + auto const& feature_names = get(it->second); + feature_names_.resize(feature_names.size()); + std::transform(feature_names.cbegin(), feature_names.cend(), feature_names_.begin(), + [](Json const& fn) { return get(fn); }); } it = learner.find("feature_types"); if (it != learner.cend()) { - auto const &feature_types = get(it->second); - feature_types_.clear(); - for (auto const &name : feature_types) { - auto type = get(name); - feature_types_.emplace_back(type); - } + auto const& feature_types = get(it->second); + feature_types_.resize(feature_types.size()); + std::transform(feature_types.cbegin(), feature_types.cend(), feature_types_.begin(), + [](Json const& fn) { return get(fn); }); } this->need_configuration_ = true;