Skip to content

Commit

Permalink
IO test.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 7, 2021
1 parent 0f6f8ea commit 868e2bb
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 13 deletions.
27 changes: 19 additions & 8 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
XGBClassifierBase, XGBRegressorBase, XGBoostLabelEncoder)


class XGBRankerMixIn:
    """Mix-in that tags an estimator as a ranker.

    scikit-learn base classes usually define ``_estimator_type``; this class
    supplies that attribute for ranking models.
    """

    _estimator_type = "ranker"


def _objective_decorator(func):
"""Decorate an objective function
Expand Down Expand Up @@ -298,6 +304,9 @@ def _more_tags(self):
'''Tags used for scikit-learn data validation.'''
return {'allow_nan': True, 'no_validation': True}

def _model_type(self):
raise NotImplementedError("Base model doesn't have model type.")

def get_booster(self):
"""Get the underlying xgboost Booster of this model.
Expand Down Expand Up @@ -442,7 +451,6 @@ def save_model(self, fname: str):
meta[k] = v
except TypeError:
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
meta['type'] = type(self).__name__
meta_str = json.dumps(meta)
self.get_booster().set_attr(scikit_learn=meta_str)
self.get_booster().save_model(fname)
Expand Down Expand Up @@ -484,12 +492,12 @@ def load_model(self, fname):
if k == 'use_label_encoder':
self.use_label_encoder = bool(v)
continue
if k == 'type' and type(self).__name__ != v:
msg = 'Current model type: {}, '.format(type(self).__name__) + \
'type of model in file: {}'.format(v)
raise TypeError(msg)
if k == 'type':
continue
if k == "_estimator_type":
if self._estimator_type != v:
raise TypeError(
"Loading an estimator with different type "
f"{self._estimator_type}, {v}"
)
states[k] = v
self.__dict__.update(states)
# Delete the attribute after load
Expand Down Expand Up @@ -849,6 +857,9 @@ def __init__(self, *, objective="binary:logistic", use_label_encoder=True, **kwa
self.use_label_encoder = use_label_encoder
super().__init__(objective=objective, **kwargs)

def _model_type(self) -> str:
return "cls"

@_deprecate_positional_args
def fit(self, X, y, *, sample_weight=None, base_margin=None,
eval_set=None, eval_metric=None,
Expand Down Expand Up @@ -1211,7 +1222,7 @@ def get_num_boosting_rounds(self):
then your group array should be ``[3, 4]``.
''')
class XGBRanker(XGBModel):
class XGBRanker(XGBModel, XGBRankerMixIn):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
@_deprecate_positional_args
def __init__(self, *, objective='rank:pairwise', **kwargs):
Expand Down
42 changes: 37 additions & 5 deletions tests/python/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ def test_dask_regressor() -> None:
with Client(cluster) as client:
X, y, w = generate_array(with_weights=True)
regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
assert regressor._estimator_type == "regressor"

regressor.set_params(tree_method='hist')
regressor.client = client
regressor.fit(X, y, sample_weight=w, eval_set=[(X, y)])
Expand All @@ -285,6 +287,8 @@ def test_dask_classifier() -> None:
y = (y * 10).astype(np.int32)
classifier = xgb.dask.DaskXGBClassifier(
verbosity=1, n_estimators=2, eval_metric='merror')
assert classifier._estimator_type == "classifier"

classifier.client = client
classifier.fit(X, y, sample_weight=w, eval_set=[(X, y)])
prediction = classifier.predict(X)
Expand Down Expand Up @@ -960,7 +964,7 @@ def run_shap(self, X: Any, y: Any, params: Dict[str, Any], client: "Client") ->

def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None:
X, y = da.from_array(X), da.from_array(y)
cls = xgb.dask.DaskXGBClassifier()
cls = xgb.dask.DaskXGBClassifier(n_estimators=10)
cls.client = client
cls.fit(X, y)
booster = cls.get_booster()
Expand All @@ -971,6 +975,7 @@ def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None:
margin = xgb.dask.predict(client, booster, test_Xy, output_margin=True).compute()
assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)

@pytest.mark.skipif(**tm.no_sklearn())
def test_shap(self, client: "Client") -> None:
from sklearn.datasets import load_boston, load_digits
X, y = load_boston(return_X_y=True)
Expand Down Expand Up @@ -1007,14 +1012,41 @@ def run_shap_interactions(
margin,
1e-5, 1e-5)

@pytest.mark.skipif(**tm.no_sklearn())
def test_shap_interactions(self, client: "Client") -> None:
from sklearn.datasets import load_boston, load_digits
from sklearn.datasets import load_boston
X, y = load_boston(return_X_y=True)
params = {'objective': 'reg:squarederror'}
self.run_shap_interactions(X, y, params, client)
X, y = load_digits(return_X_y=True)
params = {'objective': 'multi:softprob', 'num_class': 10}
self.run_shap_interactions(X, y, params, client)

@pytest.mark.skipif(**tm.no_sklearn())
def test_sklearn_io(self, client: 'Client') -> None:
    """Round-trip a fitted dask classifier through save_model/load_model.

    The saved file is loaded into a fresh dask classifier and into a
    single-node classifier.  Each must report 10 classes and reproduce the
    predictions made before saving.
    """
    from sklearn.datasets import load_digits

    features, labels = load_digits(return_X_y=True)
    d_features, d_labels = da.from_array(features), da.from_array(labels)

    fitted = xgb.dask.DaskXGBClassifier(n_estimators=10)
    fitted.client = client
    fitted.fit(d_features, d_labels)
    expected = fitted.predict(d_features)

    # Everything that touches ``path`` stays inside the ``with`` block: the
    # file disappears together with the temporary directory.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, 'cls.json')
        fitted.save_model(path)

        # Load with the distributed interface.
        dask_loaded = xgb.dask.DaskXGBClassifier()
        dask_loaded.load_model(path)
        assert dask_loaded.n_classes_ == 10
        from_dask = dask_loaded.predict(d_features)

        np.testing.assert_allclose(expected.compute(), from_dask.compute())

        # Use single node to load
        local_loaded = xgb.XGBClassifier()
        local_loaded.load_model(path)
        assert local_loaded.n_classes_ == 10
        from_local = local_loaded.predict(features)

        np.testing.assert_allclose(expected.compute(), from_local)


class TestDaskCallbacks:
Expand Down
8 changes: 8 additions & 0 deletions tests/python/test_with_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,3 +1099,11 @@ def test_boost_from_prediction_approx():
@pytest.mark.skipif(**tm.no_sklearn())
def test_boost_from_prediction_exact():
run_boost_from_prediction('exact')


@pytest.mark.skipif(**tm.no_sklearn())
def test_estimator_type():
    """Check the ``_estimator_type`` advertised by each scikit-learn wrapper.

    Classifiers must report ``"classifier"``, regressors ``"regressor"`` and
    the ranker ``"ranker"``.
    """
    # ``_estimator_type`` is usually defined by scikit-learn base classes, so
    # skip (rather than fail) when scikit-learn is not installed, matching the
    # other scikit-learn dependent tests in this module.
    assert xgb.XGBClassifier._estimator_type == "classifier"
    assert xgb.XGBRFClassifier._estimator_type == "classifier"
    assert xgb.XGBRegressor._estimator_type == "regressor"
    assert xgb.XGBRFRegressor._estimator_type == "regressor"
    assert xgb.XGBRanker._estimator_type == "ranker"

0 comments on commit 868e2bb

Please sign in to comment.