Skip to content

Commit

Permalink
Enable xgb_model in sklearn XGBClassifier and add a test. (#4092)
Browse files Browse the repository at this point in the history
* Enable xgb_model parameter in XGBClassifier scikit-learn API

#3049

* add test_XGBClassifier_resume():

test for xgb_model parameter in XGBClassifier API.

* Update test_with_sklearn.py

* Fix lint
  • Loading branch information
tmitanitky authored and hcho3 committed Jan 31, 2019
1 parent 0d0ce32 commit 59f868b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, obj=obj, feval=feval,
verbose_eval=verbose, xgb_model=None,
verbose_eval=verbose, xgb_model=xgb_model,
callbacks=callbacks)

self.objective = xgb_options["objective"]
Expand Down
39 changes: 39 additions & 0 deletions tests/python/test_with_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,3 +609,42 @@ def test_RFECV():
scale_pos_weight=0.5, silent=True)
rfecv = RFECV(estimator=bst, step=1, cv=3, scoring='neg_log_loss')
rfecv.fit(X, y)


def test_XGBClassifier_resume():
    """Check that ``XGBClassifier.fit`` honours the ``xgb_model`` argument.

    A baseline classifier is fitted and saved twice: once through the
    scikit-learn wrapper's ``save_model`` and once through the underlying
    booster's ``save_model``.  For each saved file a fresh classifier is
    fitted with ``xgb_model`` pointing at that file, and the test asserts
    that the resulting model predicts differently from the baseline and
    reaches a strictly lower training log loss.
    """
    from sklearn.datasets import load_breast_cancer
    from sklearn.metrics import log_loss

    def make_classifier():
        # Every model in this test shares the same hyper-parameters so that
        # the only difference between them is the ``xgb_model`` starting point.
        return xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8)

    with TemporaryDirectory() as tempdir:
        model1_path = os.path.join(tempdir, 'test_XGBClassifier.model')
        model1_booster_path = os.path.join(tempdir, 'test_XGBClassifier.booster')

        X, Y = load_breast_cancer(return_X_y=True)

        model1 = make_classifier()
        model1.fit(X, Y)

        pred1 = model1.predict(X)
        # log_loss takes (y_true, y_pred) and is meant to score probabilities;
        # scoring hard labels would reduce it to a scaled error rate.
        log_loss1 = log_loss(Y, model1.predict_proba(X))

        def check_resumed_from(model_path):
            # Fit a fresh classifier starting from the model stored at
            # ``model_path`` and compare it against the baseline.
            model2 = make_classifier()
            model2.fit(X, Y, xgb_model=model_path)

            pred2 = model2.predict(X)
            log_loss2 = log_loss(Y, model2.predict_proba(X))

            assert np.any(pred1 != pred2)
            assert log_loss1 > log_loss2

        # File written by the scikit-learn wrapper itself.
        model1.save_model(model1_path)
        check_resumed_from(model1_path)

        # File written by the underlying Booster instance.
        model1.get_booster().save_model(model1_booster_path)
        check_resumed_from(model1_booster_path)

0 comments on commit 59f868b

Please sign in to comment.