[dask] pass additional predict() parameters through when input is a Dask Array #4399
Very nice catch! Just a suggestion for a more descriptive variable name below.
What about early stopping for prediction? Is it supported in Dask?
https://lightgbm.readthedocs.io/en/latest/Parameters.html#pred_early_stop
Some examples of corresponding tests for non-Dask estimators:
LightGBM/tests/python_package_test/test_sklearn.py
Lines 593 to 599 in d517ba1
# Tests other parameters for the prediction works
res_engine = gbm.predict(X_test)
res_sklearn_params = clf.predict_proba(X_test,
                                       pred_early_stop=True,
                                       pred_early_stop_margin=1.0)
with pytest.raises(AssertionError):
    np.testing.assert_allclose(res_engine, res_sklearn_params)
LightGBM/tests/python_package_test/test_sklearn.py
Lines 627 to 633 in d517ba1
# Tests other parameters for the prediction works, starting from iteration 10
res_engine = gbm.predict(X_test, start_iteration=10)
res_sklearn_params = clf.predict_proba(X_test,
                                       pred_early_stop=True,
                                       pred_early_stop_margin=1.0, start_iteration=10)
with pytest.raises(AssertionError):
    np.testing.assert_allclose(res_engine, res_sklearn_params)
LightGBM/tests/python_package_test/test_engine.py
Lines 451 to 462 in c738c83
pred_parameter = {"pred_early_stop": True,
                  "pred_early_stop_freq": 5,
                  "pred_early_stop_margin": 1.5}
ret = multi_logloss(y_test, gbm.predict(X_test, **pred_parameter))
assert ret < 0.8
assert ret > 0.6  # loss will be higher than when evaluating the full model
pred_parameter = {"pred_early_stop": True,
                  "pred_early_stop_freq": 5,
                  "pred_early_stop_margin": 5.5}
ret = multi_logloss(y_test, gbm.predict(X_test, **pred_parameter))
assert ret < 0.2
I don't understand what "early stopping for prediction" actually means. Can you explain it to me? The parameter descriptions at https://lightgbm.readthedocs.io/en/latest/Parameters.html#pred_early_stop are just short phrases using the same words as the parameter name. I understand that early stopping for training means "stop the boosting process if performance on a validation set fails to improve", but I don't understand what early stopping means when you're generating predictions.
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
I guess it is something like "stop accumulating predictions of individual trees in final prediction if individual contributions are becoming insignificant".
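To make that guess concrete, here is a rough sketch of how prediction early stopping could work for a single row. It is not LightGBM's implementation: the helper names `margin_of` and `predict_row`, the margin rule, the default values, and the toy per-iteration scores are all assumptions made for illustration. Only the three parameter names (`pred_early_stop`, `pred_early_stop_freq`, `pred_early_stop_margin`) come from the LightGBM documentation.

```python
from typing import List, Sequence, Tuple


def margin_of(scores: Sequence[float]) -> float:
    """Confidence margin of the accumulated raw scores for one row.

    Assumption for this sketch: with a single score (binary) use
    2 * |score|; with several scores (multiclass) use the gap between
    the two largest scores.
    """
    if len(scores) == 1:
        return 2.0 * abs(scores[0])
    top, second = sorted(scores, reverse=True)[:2]
    return top - second


def predict_row(tree_outputs: List[Sequence[float]],
                pred_early_stop: bool = False,
                pred_early_stop_freq: int = 10,
                pred_early_stop_margin: float = 10.0) -> Tuple[List[float], int]:
    """Sum per-iteration raw scores, optionally stopping early.

    Returns (raw_scores, iterations_used).
    """
    scores = [0.0] * len(tree_outputs[0])
    used = 0
    for i, contribution in enumerate(tree_outputs, start=1):
        scores = [s + c for s, c in zip(scores, contribution)]
        used = i
        # Only check the margin every `pred_early_stop_freq` iterations.
        if (pred_early_stop
                and i % pred_early_stop_freq == 0
                and margin_of(scores) > pred_early_stop_margin):
            break
    return scores, used


# Toy model: 20 iterations, 3 classes, identical contributions each round.
trees = [[0.5, -0.5, 0.0]] * 20
full, n_full = predict_row(trees)
early, n_early = predict_row(trees,
                             pred_early_stop=True,
                             pred_early_stop_freq=2,
                             pred_early_stop_margin=1.0)
```

In this toy case the early-stopped raw scores differ from the full-model scores because fewer iterations are accumulated, which is the same kind of difference the linked tests rely on when they expect `assert_allclose` to fail.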
ahhh I see, interesting! Ok I can add calls with prediction early stopping to the tests in these PRs. From #550, the tests you linked in #4399 (review) and my own investigation it seems that
Added a test with prediction early stopping in dccf44e
Thanks for the fix!
p1_early_stop_raw = dask_classifier.predict(
    dX,
    pred_early_stop=True,
    pred_early_stop_margin=1.0,
    pred_early_stop_freq=2,
    raw_score=True
)
Just curious: why does this particular line not end with .compute()?
just an oversight, there should be a .compute(). I've opened #4412 to add it.
.predict() in the Dask estimators allows users to pass additional prediction parameters (https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters) through **kwargs. To be applied correctly, those **kwargs have to be passed through several layers of function calls inside the package.

One of those pass-throughs is currently missing, and as a result additional prediction parameters will be silently ignored when data passed to .predict() is a Dask Array.

This PR proposes fixing that and adding tests confirming that additional parameters are being passed through correctly.
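The failure mode described above can be reproduced in a few lines of plain Python. This is only a sketch of the general pattern, not the code in `lightgbm.dask`: every function name below (`inner_predict`, `predict_part_buggy`, `predict_part_fixed`, `predict`) is hypothetical, and a list of lists stands in for the chunks of a Dask Array.

```python
def inner_predict(data, raw_score=False, pred_early_stop=False, **kwargs):
    """Stand-in for the innermost predictor: reports the options it received."""
    return {
        "n_rows": len(data),
        "raw_score": raw_score,
        "pred_early_stop": pred_early_stop,
    }


def predict_part_buggy(part, **kwargs):
    # Bug: **kwargs is accepted but never forwarded, so the inner layer
    # silently falls back to its defaults. No error is raised.
    return inner_predict(part)


def predict_part_fixed(part, **kwargs):
    # Fix: forward the extra prediction parameters to the next layer.
    return inner_predict(part, **kwargs)


def predict(parts, predict_part, **kwargs):
    """Outer layer: apply `predict_part` to every chunk of the input."""
    return [predict_part(part, **kwargs) for part in parts]


parts = [[1, 2], [3]]  # two "chunks" of input rows
buggy = predict(parts, predict_part_buggy, raw_score=True, pred_early_stop=True)
fixed = predict(parts, predict_part_fixed, raw_score=True, pred_early_stop=True)
```

Because the buggy variant returns results without complaint, a test has to compare predictions made with and without the extra parameters to notice that they were dropped.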
Notes for Reviewers
I looked at lightgbm.basic._InnerPredictor.predict() for the names of specific parameters to be passed through.
LightGBM/python-package/lightgbm/basic.py
Lines 661 to 663 in bd21efe