[python] raise an informative error instead of segfaulting when custom objective produces incorrect output #4815
Conversation
@yaxxie Thanks a lot for this PR!
This assumption is wrong. Also, what do you think about moving this check to the cpp side (`LightGBM/src/boosting/goss.hpp`, line 59 in d517ba1; `LightGBM/src/boosting/gbdt.cpp`, line 371 in b0137de) so that all language packages will benefit from it?
Thanks @StrikerRUS, I figured the assumption would be wrong when I saw some failing tests. If we move this to the cpp side, it would require an API change, right?
Sorry, could you please clarify what API changes do you mean? I believe something like

```cpp
int64_t total_size = static_cast<int64_t>(num_data_) * num_tree_per_iteration_;
CHECK_EQ(gradients.size(), total_size);
CHECK_EQ(hessians.size(), total_size);
```

would be enough.
Happy to try this out, but I can't see how it would work. When the data array comes from outside the lib, how do we know its size?
Sorry, didn't get your question.
The other C-API entry points pass size variables where they need to know them (such as when we construct from a mat and need a size variable there, or when we retrieve string names and need to know how much was allocated to the underlying buffer so as not to write past allocated memory). This is why the API can segfault when passed incorrectly sized gradients: an underlying assumption is made about the size of the allocated buffers which is not (and cannot be) verified by the library itself. We could have the user pass the sizes to the API call, but this would still require changes in all language implementations. I'm happy for someone to point out where I've gone wrong with the above. (I tried the code sample provided; it won't compile because `float*`s don't have a `size()` method.)
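To make the pointer-only contract concrete, here is a hedged sketch of how a wrapper hands gradients to this entry point via ctypes. `LGBM_BoosterUpdateOneIterCustom` is the real C API function; `_LIB` (the loaded `lib_lightgbm`) and `handle` are assumed to exist:

```python
import ctypes

import numpy as np

def update_one_iter_custom(_LIB, handle, grad, hess):
    # The C function receives only raw float pointers; it cannot recover the
    # buffer lengths, so the caller must have allocated exactly
    # num_class * num_train_data floats in each array.
    grad = np.ascontiguousarray(grad, dtype=np.float32)
    hess = np.ascontiguousarray(hess, dtype=np.float32)
    is_finished = ctypes.c_int(0)
    _LIB.LGBM_BoosterUpdateOneIterCustom(
        handle,
        grad.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        hess.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        ctypes.byref(is_finished))
    return bool(is_finished.value)
```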
@yaxxie Thanks for working on this! Yes, I think there would need to be a change in the C API. Currently the C API only accepts pointers to the gradients and hessians. If we want to know the length of an array allocated outside the C API, we must add new parameters. But a change in the C API may require changes in the code of every language package. So maybe doing the check on the Python and R side is preferable. @StrikerRUS WDYT?
We'd also want to update the docs for the C API to make it explicit that there is a length expectation for the arrays of floats passed.
@yaxxie Sorry, I didn't notice what type those arguments have in the C API (raw pointers). Can we use the fact that those arrays should always have length `num_class * num_train_data`? If we cannot do any checks with raw pointers, I'm OK with this way.
@StrikerRUS I think the fix itself is what guarantees this. To check the size on the C API side, we must know the exact length of the allocated array on the Python side.
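For context, a minimal sketch of the kind of Python-side validation being discussed; the helper name and exact wording are illustrative, not necessarily what the merged patch uses:

```python
def _check_custom_objective_output(grad, hess, num_train_data, num_models):
    """Fail fast on wrongly sized arrays from a custom objective,
    instead of letting the C API read past the buffers."""
    if len(grad) != len(hess):
        raise ValueError(
            f"Lengths of gradient ({len(grad)}) and Hessian ({len(hess)}) don't match"
        )
    expected = num_train_data * num_models
    if len(grad) != expected:
        raise ValueError(
            f"Lengths of gradient ({len(grad)}) and Hessian ({len(hess)}) don't match "
            f"length of train_set ({num_train_data}) * "
            f"number of models per one iteration ({num_models})"
        )
```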
Please consider checking some of my minor suggestions below:
@StrikerRUS Thanks, will get around to doing these properly soon.
@yaxxie Thank you! As for modifying the C API docs, edit the comments in Doxygen format in this file.
@StrikerRUS Anything else needed?
@jameslamb @shiyu1994 Would you like to be a second reviewer for this PR?
Thanks very much for this!
I followed the conversation with @StrikerRUS and @shiyu1994 and understand why we've chosen to do this check on the Python / R side instead of in C/C++.
Please see two small suggestions to make the tests slightly stricter. Would you please also write up a feature request at https://github.com/microsoft/LightGBM/issues documenting the need to do this same work for the R package?
```python
bad_bst_multi = lgb.Booster({'objective': "none", "num_class": len(classes)}, ds_multiclass)
good_bst_multi = lgb.Booster({'objective': "none", "num_class": len(classes)}, ds_multiclass)
good_bst_binary.update(fobj=_good_gradients)
with pytest.raises(ValueError):
```
Suggested change:

```diff
-with pytest.raises(ValueError):
+with pytest.raises(ValueError, match="number of models per one iteration (1)"):
```
Could you please use `match` to look for a specific error message? That way, this test won't silently pass if a change results in `lightgbm` raising a different, unrelated `ValueError`.
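For reference, `pytest.raises` applies `match` as a regular expression (via `re.search`) against the string form of the raised exception; a standalone illustration:

```python
import pytest

def _raise():
    raise ValueError("number of models per one iteration (1)")

# any substring pattern of the message will match
with pytest.raises(ValueError, match="number of models per one iteration"):
    _raise()
```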
```python
with pytest.raises(ValueError):
    bad_bst_binary.update(fobj=_bad_gradients)
good_bst_multi.update(fobj=_good_gradients)
with pytest.raises(ValueError):
```
Suggested change:

```diff
-with pytest.raises(ValueError):
+with pytest.raises(ValueError, match="number of models per one iteration (3)"):
```
@yaxxie @StrikerRUS I just changed the title of this PR to hopefully be a bit more informative for the purposes of release notes.
@jameslamb I opened #4905 and pushed a commit to address your remarks. Please do let me know if anything else is required.
Thanks for the test changes and for opening #4905! Please see a few more small suggestions.
```python
bad_bst_multi = lgb.Booster({'objective': "none", "num_class": len(classes)}, ds_multiclass)
good_bst_multi = lgb.Booster({'objective': "none", "num_class": len(classes)}, ds_multiclass)
good_bst_binary.update(fobj=_good_gradients)
with pytest.raises(ValueError, match="number of models per one iteration \(1\)"):
```
Suggested change:

```diff
-with pytest.raises(ValueError, match="number of models per one iteration \(1\)"):
+with pytest.raises(ValueError, match="number of models per one iteration \\(1\\)"):
```
Please see these linting errors from https://github.com/microsoft/LightGBM/runs/4613473477?check_suite_focus=true:

```
./tests/python_package_test/test_basic.py:604:78: W605 invalid escape sequence '('
./tests/python_package_test/test_basic.py:604:81: W605 invalid escape sequence ')'
./tests/python_package_test/test_basic.py:607:79: W605 invalid escape sequence '('
./tests/python_package_test/test_basic.py:607:95: W605 invalid escape sequence ')'
```
Please don't escape any symbols, for readability purposes. Just add a `*` symbol at the end:

Suggested change:

```diff
-with pytest.raises(ValueError, match="number of models per one iteration \(1\)"):
+with pytest.raises(ValueError, match="number of models per one iteration (1) *"):
```
The escape is necessary; `(` and `)` are characters which mean something to the regular expression engine. I'll switch to using `re.escape`.
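A minimal sketch of the `re.escape` approach; `_raise_mismatch` stands in for the `bad_bst_binary.update(...)` call in the test:

```python
import re

import pytest

def _raise_mismatch():
    raise ValueError("number of models per one iteration (1)")

# re.escape backslash-escapes regex metacharacters such as parentheses,
# avoiding both hand-written escapes and flake8's W605 warnings
with pytest.raises(ValueError, match=re.escape("number of models per one iteration (1)")):
    _raise_mismatch()
```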
```python
with pytest.raises(ValueError, match="number of models per one iteration \(1\)"):
    bad_bst_binary.update(fobj=_bad_gradients)
good_bst_multi.update(fobj=_good_gradients)
with pytest.raises(ValueError, match=f"number of models per one iteration \({len(classes)}\)"):
```
Suggested change:

```diff
-with pytest.raises(ValueError, match=f"number of models per one iteration \({len(classes)}\)"):
+with pytest.raises(ValueError, match=f"number of models per one iteration \\({len(classes)}\\)"):
```
```python
X = np.random.randn(100, 5)
y_binary = np.random.choice([0, 1], 100)
classes = [0, 1, 2]
y_multiclass = np.random.choice(classes, 100)
```
Suggested change:

```diff
-X = np.random.randn(100, 5)
-y_binary = np.random.choice([0, 1], 100)
-classes = [0, 1, 2]
-y_multiclass = np.random.choice(classes, 100)
+X = np.random.randn(100, 5)
+y_binary = np.array([0] * 50 + [1] * 50)
+classes = [0, 1, 2]
+y_multiclass = np.random.choice([0] * 33 + [1] * 33 + [2] * 34)
```
Sorry, just thought of this... can you please remove the randomness from this data construction? Since you're using completely random data and not testing the produced models, the values don't really matter. Choosing randomly and using such a small amount of data makes it possible that these tests could fail randomly due to situations like "`y_binary` is all 0s". It may seem like a small probability, but consider that the Python tests run about 40 times on every commit to every pull request in this project.
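One deterministic way to build such labels (an illustration of the idea, not the exact code that was merged):

```python
import numpy as np

# fixed, balanced label vectors: no test run can end up with a single class
y_binary = np.repeat([0, 1], 50)                   # 50 zeros then 50 ones
y_multiclass = np.repeat([0, 1, 2], [33, 33, 34])  # 100 labels over 3 classes
```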
@jameslamb let me know if a946d28 addresses your concerns.
looks ok to me, thanks very much for the help!
```diff
@@ -572,6 +572,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterRefit(BoosterHandle handle,
 /*!
  * \brief Update the model by specifying gradient and Hessian directly
  * (this can be used to support customized loss functions).
+ * \note
+ * The length of the arrays referenced by ``grad`` and ``hess`` must be equal to
+ * ``num_class * num_train_data``, this is not verified by the library, the caller must ensure this.
```
Should this be "this IS verified by the library", or should we simply delete this last sentence? Because we are actually verifying this through this pull request.
This is the C-API docs -- the context here is the `lightgbm.so` library, rather than the Python or R libraries. What we're saying is that a caller of this function, `LGBM_BoosterUpdateOneIterCustom`, is responsible for ensuring that the condition is met. I'm happy to tweak the wording, but what the Python library does as a convenience to the user is not applicable here.
Thanks for fixing the custom objective function signature in the recent commit.
This pull request has been automatically locked since there has not been any recent activity since it was closed. To start a new related discussion, open a new issue at https://github.com/microsoft/LightGBM/issues including a reference to this.
Using the following code can cause a segfault due to `LGBM_BoosterUpdateOneIterCustom` making an assumption that the passed float arrays match the length of the training data. A faulty `fobj` function can cause at best totally incorrect boosting and at worst a segmentation fault. This small patch prevents this from occurring. I noticed it while adding support for `LGBM_BoosterUpdateOneIterCustom` in the Julia package (see https://github.com/IQVIA-ML/LightGBM.jl/pull/114/files).
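A hedged sketch of the failure mode: a hypothetical custom objective that returns wrongly sized arrays, which this patch turns from a potential segfault into an informative `ValueError`:

```python
import numpy as np
import lightgbm as lgb

X = np.random.randn(100, 5)
y = np.repeat([0, 1], 50)
ds = lgb.Dataset(X, y)

def bad_objective(preds, train_data):
    # wrong: one element instead of num_train_data gradients/Hessians
    return np.ones(1, dtype=np.float32), np.ones(1, dtype=np.float32)

booster = lgb.Booster({'objective': 'none'}, ds)
# before this patch the library reads past the tiny buffers (possible
# segfault); with it, an informative ValueError is raised instead
booster.update(fobj=bad_objective)
```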