Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix loading latest XGBoost binary model. #144

Merged
merged 4 commits into from
Feb 14, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 31 additions & 11 deletions src/frontend/xgboost.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <dmlc/memory_io.h>
#include <treelite/frontend.h>
#include <treelite/tree.h>
#include <algorithm>
#include <memory>
#include <queue>
#include <cstring>
Expand Down Expand Up @@ -42,6 +43,15 @@ namespace {

typedef float bst_float;

/*
 * Converts a base score from probability space back into margin (raw score)
 * space by inverting the objective's link function. Mirrors XGBoost's own
 * ProbToMargin helpers, used when importing models saved by XGBoost >= 1.0,
 * where the stored global bias is the untransformed user-provided value.
 */
struct ProbToMargin {
  // Inverse sigmoid (logit): -log(1/p - 1), i.e. log(p / (1 - p)).
  static float Sigmoid(float base_score) {
    const float inverse_odds = 1.0f / base_score - 1.0f;
    return -logf(inverse_odds);
  }
  // Inverse exponential link: the natural log of the base score.
  static float Exponential(float base_score) {
    return logf(base_score);
  }
};

/* peekable input stream implemented with a ring buffer */
class PeekableInputStream {
public:
Expand Down Expand Up @@ -143,8 +153,11 @@ struct LearnerModelParam {
int num_class;
int contain_extra_attrs;
int contain_eval_metrics;
int pad2[29];
uint32_t major_version;
uint32_t minor_version;
int pad2[27];
};
static_assert(sizeof(LearnerModelParam) == 136, "This is the size defined in XGBoost.");

struct GBTreeModelParam {
int num_trees;
Expand Down Expand Up @@ -340,17 +353,9 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) {
CHECK_EQ(fp->Read(&mparam_, sizeof(mparam_)), sizeof(mparam_))
<< "Ill-formed XGBoost model file: corrupted header";
{
// backward compatibility code for compatible with old model type
// for new model, Read(&name_obj_) is suffice
uint64_t len;
CHECK_EQ(fp->Read(&len, sizeof(len)), sizeof(len))
<< "Ill-formed XGBoost model file: corrupted header";
if (len >= std::numeric_limits<unsigned>::max()) {
int gap;
CHECK_EQ(fp->Read(&gap, sizeof(gap)), sizeof(gap))
<< "Ill-formed XGBoost model file: corrupted header";
len = len >> static_cast<uint64_t>(32UL);
}
if (len != 0) {
name_obj_.resize(len);
CHECK_EQ(fp->Read(&name_obj_[0], len), len)
Expand Down Expand Up @@ -382,6 +387,10 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) {
}
CHECK_EQ(gbm_param_.num_roots, 1) << "multi-root trees not supported";

// Before XGBoost 1.0.0, the global bias saved in model is a transformed value. After
// 1.0 it's the original value provided by user.
bool need_transform_to_margin = mparam_.major_version >= 1;

/* 2. Export model */
treelite::Model model;
model.num_feature = mparam_.num_feature;
Expand All @@ -390,6 +399,17 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) {

// set global bias
model.param.global_bias = static_cast<float>(mparam_.base_score);
std::vector<std::string> exponential_family {
"count:poisson", "reg:gamma", "reg:tweedie"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to do something special for survival:cox? or assert that it is not supported?

Copy link
Collaborator

@hcho3 hcho3 Feb 13, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it is not necessary, since the label is not in log scale for survival:cox. For survival:cox, the convention is to use negative label to represent right-censored data.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hcho3 has a better understanding of survival models than I do.

};
if (need_transform_to_margin) {
if (name_obj_ == "reg:logistic" || name_obj_ == "binary:logistic") {
model.param.global_bias = ProbToMargin::Sigmoid(model.param.global_bias);
} else if (std::find(exponential_family.cbegin() , exponential_family.cend(), name_obj_)
!= exponential_family.cend()) {
model.param.global_bias = ProbToMargin::Exponential(model.param.global_bias);
}
}

// set correct prediction transform function, depending on objective function
if (name_obj_ == "multi:softmax") {
Expand All @@ -399,8 +419,8 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) {
} else if (name_obj_ == "reg:logistic" || name_obj_ == "binary:logistic") {
model.param.pred_transform = "sigmoid";
model.param.sigmoid_alpha = 1.0f;
} else if (name_obj_ == "count:poisson" || name_obj_ == "reg:gamma"
|| name_obj_ == "reg:tweedie") {
} else if (std::find(exponential_family.cbegin() , exponential_family.cend(), name_obj_)
!= exponential_family.cend()) {
model.param.pred_transform = "exponential";
} else {
model.param.pred_transform = "identity";
Expand Down
25 changes: 25 additions & 0 deletions tests/python/test_xgboost_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,28 @@ def test_xgb_iris(self):
assert predictor.pred_transform == 'max_index'
assert predictor.global_bias == 0.5
assert predictor.sigmoid_alpha == 1.0

def test_logistic(self):
np.random.seed(0)
kRows = 16
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a similar test for the exponential families? You could parametrize and reuse this with just a different randint limit for count:poisson probably.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding it.

kCols = 8
X = np.random.randn(kRows, kCols)
y = np.random.randint(0, 2, size=kRows)
assert y.min() == 0
assert y.max() == 1

dtrain = xgboost.DMatrix(X, y)
booster = xgboost.train({'objective': 'binary:logistic'}, dtrain=dtrain,
num_boost_round=4)
expected_pred = booster.predict(dtrain)
model = treelite.Model.from_xgboost(booster)
libpath = libname('./logistic{}')
batch = treelite.runtime.Batch.from_npy2d(X)
for toolchain in os_compatible_toolchains():
model.export_lib(toolchain=toolchain, libpath=libpath,
params={}, verbose=True)
predictor = treelite.runtime.Predictor(libpath=libpath, verbose=True)
out_pred = predictor.predict(batch)
assert_almost_equal(out_pred, expected_pred)
assert predictor.num_feature == kCols
assert predictor.global_bias == 0