Skip to content

Commit

Permalink
Fix feature names and types in output model slice.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jul 4, 2021
1 parent b56d3d5 commit f341958
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,8 @@ class LearnerImpl : public LearnerIO {
out_impl->mparam_ = this->mparam_;
out_impl->attributes_ = this->attributes_;
out_impl->learner_model_param_ = this->learner_model_param_;
out_impl->SetFeatureNames(this->feature_names_);
out_impl->SetFeatureTypes(this->feature_types_);
out_impl->LoadConfig(config);
out_impl->Configure();
return out_impl;
Expand Down
3 changes: 3 additions & 0 deletions tests/python/test_basic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,10 +379,13 @@ def test_slice(self, booster):
'num_parallel_tree': 4, 'subsample': 0.5, 'num_class': 3, 'booster': booster,
'objective': 'multi:softprob'},
num_boost_round=num_boost_round, dtrain=dtrain)
booster.feature_types = ["q"] * X.shape[1]

assert len(booster.get_dump()) == total_trees
beg = 3
end = 7
sliced: xgb.Booster = booster[beg: end]
assert sliced.feature_types == booster.feature_types

sliced_trees = (end - beg) * num_parallel_tree * num_classes
assert sliced_trees == len(sliced.get_dump())
Expand Down

0 comments on commit f341958

Please sign in to comment.