Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dask] pass additional predict() parameters through when input is a Dask Array #4399

Merged
merged 5 commits into from
Jun 26, 2021

Conversation

jameslamb
Copy link
Collaborator

.predict() in the Dask estimators allows users to pass additional prediction parameters (https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters) through **kwargs. To be applied correctly, those **kwargs have to be passed through several layers of function calls inside the package.

One of those pass-throughs is currently missing, and as a result additional prediction parameters will be silently ignored when data passed to .predict() is a Dask Array.

This PR proposes fixing that and adding tests confirming that additional parameters are being passed through correctly.

Notes for Reviewers

I looked at lightgbm.basic._InnerPredictor.predict() for the names of specific parameters to be passed through.

def predict(self, data, start_iteration=0, num_iteration=-1,
raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False,
is_reshape=True):

@jameslamb jameslamb added the fix label Jun 23, 2021
@jameslamb jameslamb requested a review from StrikerRUS June 23, 2021 04:16
@jameslamb jameslamb changed the title [dask] pass predict() kwargs through when input is a Dask Array [dask] pass additional predict() parameters through when input is a Dask Array Jun 23, 2021
Copy link
Collaborator

@StrikerRUS StrikerRUS left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice catch! Just suggestion for more descriptive variable name below.
What about early stopping for prediction? Is it supported in Dask?
https://lightgbm.readthedocs.io/en/latest/Parameters.html#pred_early_stop

Some examples of corresponding tests for non-Dask estimators:

# Tests other parameters for the prediction works
res_engine = gbm.predict(X_test)
res_sklearn_params = clf.predict_proba(X_test,
pred_early_stop=True,
pred_early_stop_margin=1.0)
with pytest.raises(AssertionError):
np.testing.assert_allclose(res_engine, res_sklearn_params)

# Tests other parameters for the prediction works, starting from iteration 10
res_engine = gbm.predict(X_test, start_iteration=10)
res_sklearn_params = clf.predict_proba(X_test,
pred_early_stop=True,
pred_early_stop_margin=1.0, start_iteration=10)
with pytest.raises(AssertionError):
np.testing.assert_allclose(res_engine, res_sklearn_params)

pred_parameter = {"pred_early_stop": True,
"pred_early_stop_freq": 5,
"pred_early_stop_margin": 1.5}
ret = multi_logloss(y_test, gbm.predict(X_test, **pred_parameter))
assert ret < 0.8
assert ret > 0.6 # loss will be higher than when evaluating the full model
pred_parameter = {"pred_early_stop": True,
"pred_early_stop_freq": 5,
"pred_early_stop_margin": 5.5}
ret = multi_logloss(y_test, gbm.predict(X_test, **pred_parameter))
assert ret < 0.2

tests/python_package_test/test_dask.py Outdated Show resolved Hide resolved
tests/python_package_test/test_dask.py Outdated Show resolved Hide resolved
tests/python_package_test/test_dask.py Outdated Show resolved Hide resolved
tests/python_package_test/test_dask.py Outdated Show resolved Hide resolved
tests/python_package_test/test_dask.py Outdated Show resolved Hide resolved
tests/python_package_test/test_dask.py Outdated Show resolved Hide resolved
@jameslamb
Copy link
Collaborator Author

What about early stopping for prediction? Is it supported in Dask?

I don't understand what "early stopping for prediction" actually means, can you explain it to me? The parameter descriptions at https://lightgbm.readthedocs.io/en/latest/Parameters.html#pred_early_stop are just short phrases using the same words as the parameter name (e.g. pred_early_stop_margin = "the threshold of margin in early-stopping prediction"), and I don't understand from the unit tests linked to in #4399 (review) what that functionality actually does.

I understand that early stopping for training means "stop the boosting process if performance on a validation set fails to improve", but I don't understand what early stopping means when you're generating predictions.

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
@StrikerRUS
Copy link
Collaborator

I guess it is something like "stop accumulating predictions of individual trees in final prediction if individual contributions are becoming insignificant".
Here is explanation from original author I just found: #565 (comment).
Original PR: #550.

@jameslamb
Copy link
Collaborator Author

jameslamb commented Jun 24, 2021

ahhh I see, interesting! Ok I can add calls with prediction early stopping to the tests in these PRs.

From #550, the tests you linked in #4399 (review) and my own investigation it seems that that parameter those parameters will only have an effect for classification objectives, so I'll only add it to the classifier tests.

@jameslamb
Copy link
Collaborator Author

Added a test with prediction early stopping in dccf44e

Copy link
Collaborator

@StrikerRUS StrikerRUS left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix!

Comment on lines +277 to +283
p1_early_stop_raw = dask_classifier.predict(
dX,
pred_early_stop=True,
pred_early_stop_margin=1.0,
pred_early_stop_freq=2,
raw_score=True
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious: why does this particular line not ends with .compute()?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just an oversight, there should be a .compute(). I've opened #4412 to add it.

@StrikerRUS StrikerRUS merged commit 8116d88 into master Jun 26, 2021
@StrikerRUS StrikerRUS deleted the fix/dask-predict-kwargs branch June 26, 2021 13:01
@github-actions
Copy link

This pull request has been automatically locked since there has not been any recent activity since it was closed. To start a new related discussion, open a new issue at https://github.com/microsoft/LightGBM/issues including a reference to this.

@github-actions github-actions bot locked as resolved and limited conversation to collaborators Aug 23, 2023
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants