Calculate base_score based on input labels. #8107
The old binary format and the |
I worked around it by limiting the base_score to a single scalar for now.
// - model loaded from new binary or JSON.
// - model is created from scratch.
// - model is configured second time due to change of parameter
CHECK(obj_);
This configuration is very fragile.
I agree. That's why I really want to remove the old model format.
I removed the use of |
src/objective/regression_obj.cu
Outdated
// average base score across all valid workers
rabit::Allreduce<rabit::op::Sum>(out.Values().data(), out.Values().size());
std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out),
               [world](float v) { return v / world; });
Better than it was before. I wonder if it can be made more robust with weighted averaging. The MSE version will also need to use a weighted average. Small example:
Worker 0 labels: 0 0 0
Worker 1 labels: 1000
True median: 0
True median mean abs error: 250
Estimated median (current method): 500
Estimated median (current method) mean abs error: 500
Estimated median (weighted average): 250
Estimated median (weighted average) mean abs error: 375
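The numbers above can be reproduced with a small standalone sketch. It does not use XGBoost or rabit; the helper names (`simple_average`, `weighted_average`, `mean_abs_error`) are hypothetical, and each list in `shards` stands in for one worker's labels.

```python
from statistics import median

def simple_average(shards):
    # Average the per-worker medians, every worker counted equally
    # (the "current method" in the example above).
    meds = [median(s) for s in shards]
    return sum(meds) / len(meds)

def weighted_average(shards):
    # Average the per-worker medians, weighted by each worker's label count.
    total = sum(len(s) for s in shards)
    return sum(median(s) * len(s) for s in shards) / total

def mean_abs_error(shards, estimate):
    # Mean absolute error of a constant prediction over all labels.
    labels = [y for s in shards for y in s]
    return sum(abs(y - estimate) for y in labels) / len(labels)

shards = [[0.0, 0.0, 0.0], [1000.0]]  # worker 0, worker 1

for name, est in [
    ("true median", median([y for s in shards for y in s])),
    ("simple average", simple_average(shards)),
    ("weighted average", weighted_average(shards)),
]:
    print(f"{name}: estimate={est}, MAE={mean_abs_error(shards, est)}")
```

Neither averaging scheme recovers the true median here, but weighting by worker size moves the estimate closer to it and lowers the error.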
Thank you for the suggestion, changed it to the weighted average. I have adapted your example into a Python test.
#4321 .
This PR calculates the base_score from labels for l1 regression and saves it to the output model. Will follow up on other objectives as well. Multi-target and multi-class are not yet supported due to the binary model parameter.