Skip to content

Commit

Permalink
Fix prediction configuration. (#7159)
Browse files Browse the repository at this point in the history
After the predictor parameter was added to the constructor, this configuration was broken.
  • Loading branch information
trivialfis authored Aug 11, 2021
1 parent 9600ca8 commit 3f38d98
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,8 +798,8 @@ def _can_use_inplace_predict(self) -> bool:
# error with incompatible data type.
# Inplace predict doesn't handle as many data types as DMatrix, but it's
# sufficient for dask interface where input is simpiler.
params = self.get_params()
if params.get("predictor", None) is None and self.booster != "gblinear":
predictor = self.get_params().get("predictor", None)
if predictor in ("auto", None) and self.booster != "gblinear":
return True
return False

Expand Down
17 changes: 17 additions & 0 deletions tests/python/test_with_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,3 +1254,20 @@ def test_estimator_reg(estimator, check):
estimator.set_params(**xgb.XGBRegressor().fit(X, y).get_params())

check(estimator)


def test_prediction_config():
reg = xgb.XGBRegressor()
assert reg._can_use_inplace_predict() is True

reg.set_params(predictor="cpu_predictor")
assert reg._can_use_inplace_predict() is False

reg.set_params(predictor="auto")
assert reg._can_use_inplace_predict() is True

reg.set_params(predictor=None)
assert reg._can_use_inplace_predict() is True

reg.set_params(booster="gblinear")
assert reg._can_use_inplace_predict() is False

0 comments on commit 3f38d98

Please sign in to comment.