Skip to content

Commit

Permalink
move extra_attr to tparam
Browse files Browse the repository at this point in the history
  • Loading branch information
Chen Qin committed Aug 26, 2019
1 parent eec98fa commit 9134b63
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ struct LearnerTrainParam : public dmlc::Parameter<LearnerTrainParam> {
DataSplitMode dsplit;
// flag to disable default metric
int disable_default_eval_metric;
// flag to enable rabit cache checkpoint
int rabit_cache_version;

std::string booster;
std::string objective;
Expand All @@ -121,6 +123,9 @@ struct LearnerTrainParam : public dmlc::Parameter<LearnerTrainParam> {
DMLC_DECLARE_FIELD(disable_default_eval_metric)
.set_default(0)
.describe("flag to disable default metric. Set to >0 to disable");
DMLC_DECLARE_FIELD(rabit_cache_version)
.set_default(0)
.describe("flag to enable rabit cache checkpoint. Set to >0 to disable");
DMLC_DECLARE_FIELD(booster)
.set_default("gbtree")
.describe("Gradient booster used for training.");
Expand Down Expand Up @@ -191,8 +196,6 @@ class LearnerImpl : public Learner {
}

void Load(dmlc::Stream* fi) override {
bool is_09x_format = false;

generic_param_.InitAllowUnknown(Args{});
tparam_.Init(std::vector<std::pair<std::string, std::string>>{});
// TODO(tqchen) mark deprecation of old format.
Expand Down Expand Up @@ -272,10 +275,15 @@ class LearnerImpl : public Learner {
kv.second = "cpu_predictor";
LOG(INFO) << "Switch gpu_predictor to cpu_predictor.";
}
if (saved_param == "max_depth") cfg_[saved_param] = kv.second;
if (saved_param == "tree_method") cfg_[saved_param] = kv.second;
if (saved_param == "max_depth") {
cfg_[saved_param] = kv.second;
tparam_.rabit_cache_version = 1;
}
if (saved_param == "tree_method") {
cfg_[saved_param] = kv.second;
tparam_.rabit_cache_version = 1;
}
}
if (kv.first.find("model_format") == 0) is_09x_format = true;
}
attributes_ = std::map<std::string, std::string>(attr.begin(), attr.end());
}
Expand Down Expand Up @@ -307,7 +315,10 @@ class LearnerImpl : public Learner {
p_metric->Configure({cfg_.begin(), cfg_.end()});
}

if (is_09x_format) {
if (tparam_.rabit_cache_version != 0) {
CHECK_EQ(fi->Read(&tparam_.rabit_cache_version, sizeof(tparam_.rabit_cache_version)),
sizeof(tparam_.rabit_cache_version))
<< "BoostLearner: wrong version format";
// inference tests show it runs mulitple rabit init without shutdown
// we need to guard against inference also call this with empty payload
CHECK_EQ(fi->Read(&tparam_.dsplit, sizeof(tparam_.dsplit)), sizeof(tparam_.dsplit))
Expand Down Expand Up @@ -339,9 +350,6 @@ class LearnerImpl : public Learner {
LearnerModelParam mparam = mparam_; // make a copy to potentially modify
std::vector<std::pair<std::string, std::string> > extra_attr;

mparam.contain_extra_attrs = 1;
extra_attr.emplace_back("model_format", "0.9x");

// extra attributed to be added just before saving
if (tparam_.objective == "count:poisson") {
auto it = cfg_.find("max_delta_step");
Expand Down Expand Up @@ -393,6 +401,7 @@ class LearnerImpl : public Learner {
}
fo->Write(metr);
}
fo->Write(&tparam_.rabit_cache_version, sizeof(int));
fo->Write(&tparam_.dsplit, sizeof(DataSplitMode));
fo->Write(&tparam_.disable_default_eval_metric, sizeof(int));
}
Expand Down

0 comments on commit 9134b63

Please sign in to comment.