Skip to content

Commit

Permalink
Fix handling of global bias for binary:logitraw objective of XGBoost (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 authored Dec 29, 2020
1 parent 191a474 commit b6c9d39
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/frontend/xgboost.cc
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ inline std::unique_ptr<treelite::Model> ParseStream(dmlc::Stream* fi) {
// 1.0 it's the original value provided by user.
const bool need_transform_to_margin = mparam_.major_version >= 1;
if (need_transform_to_margin) {
treelite::details::xgboost::TransformGlobalBiasToMargin(name_obj_, &model->param);
treelite::details::xgboost::TransformGlobalBiasToMargin(&model->param);
}

// traverse trees
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/xgboost/xgboost.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ extern const std::vector<std::string> exponential_objectives;
void SetPredTransform(const std::string& objective_name, ModelParam* param);

// Transform the global bias parameter from probability into margin score
void TransformGlobalBiasToMargin(const std::string& objective_name, ModelParam* param);
void TransformGlobalBiasToMargin(ModelParam* param);

enum FeatureType {
kNumerical = 0,
Expand Down
3 changes: 1 addition & 2 deletions src/frontend/xgboost_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,7 @@ bool XGBoostModelHandler::EndObject(std::size_t memberCount) {
// 1.0 it's the original value provided by user.
const bool need_transform_to_margin = (version[0] >= 1);
if (need_transform_to_margin) {
treelite::details::xgboost::TransformGlobalBiasToMargin(
output.objective_name, &output.model->param);
treelite::details::xgboost::TransformGlobalBiasToMargin(&output.model->param);
}
return pop_handler();
}
Expand Down
9 changes: 1 addition & 8 deletions src/frontend/xgboost_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,8 @@ void SetPredTransform(const std::string& objective_name, ModelParam* param) {
}

// Transform the global bias parameter from probability into margin score
void TransformGlobalBiasToMargin(const std::string& objective_name, ModelParam* param) {
void TransformGlobalBiasToMargin(ModelParam* param) {
std::string bias_transform{param->pred_transform};
if (objective_name == "binary:logitraw") {
// Special handling for 'logitraw', where the global bias is transformed with 'sigmoid',
// but the prediction is returned un-transformed.
CHECK_EQ(bias_transform, "identity");
bias_transform = "sigmoid";
}

if (bias_transform == "sigmoid") {
param->global_bias = ProbToMargin::Sigmoid(param->global_bias);
} else if (bias_transform == "exponential") {
Expand Down
4 changes: 2 additions & 2 deletions tests/python/test_xgboost_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,18 @@ def test_xgb_iris(tmpdir, toolchain, objective, model_format, expected_pred_tran
np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5)


@pytest.mark.parametrize('toolchain', os_compatible_toolchains())
@pytest.mark.parametrize('model_format', ['binary', 'json'])
@pytest.mark.parametrize('objective,max_label,expected_global_bias',
[('binary:logistic', 2, 0),
('binary:hinge', 2, 0.5),
('binary:logitraw', 2, 0),
('binary:logitraw', 2, 0.5),
('count:poisson', 4, math.log(0.5)),
('rank:pairwise', 5, 0.5),
('rank:ndcg', 5, 0.5),
('rank:map', 5, 0.5)],
ids=['binary:logistic', 'binary:hinge', 'binary:logitraw',
'count:poisson', 'rank:pairwise', 'rank:ndcg', 'rank:map'])
@pytest.mark.parametrize('toolchain', os_compatible_toolchains())
def test_nonlinear_objective(tmpdir, objective, max_label, expected_global_bias, toolchain,
model_format):
# pylint: disable=too-many-locals,too-many-arguments
Expand Down

0 comments on commit b6c9d39

Please sign in to comment.