Skip to content

Commit

Permalink
Fix gamma neg log likelihood.
Browse files Browse the repository at this point in the history
Restore non-negative requirement.
  • Loading branch information
trivialfis committed Sep 29, 2021
1 parent b2d8431 commit becbbd5
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
3 changes: 1 addition & 2 deletions src/metric/elementwise_metric.cu
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,7 @@ struct EvalGammaNLogLik {
float constexpr kPsi = 1.0;
bst_float theta = -1. / py;
bst_float a = kPsi;
// b = -std::log(-theta);
float b = 1.0f;
float b = -std::log(-theta);
// c = 1. / kPsi * std::log(y/kPsi) - std::log(y) - common::LogGamma(1. / kPsi);
// = 1.0f * std::log(y) - std::log(y) - 0 = 0
float c = 0;
Expand Down
4 changes: 2 additions & 2 deletions src/objective/regression_obj.cu
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ class GammaRegression : public ObjFunction {
bst_float p = _preds[_idx];
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
bst_float y = _labels[_idx];
if (y <= 0.0f) {
if (y < 0.0f) {
_label_correct[0] = 0;
}
_out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w);
Expand All @@ -433,7 +433,7 @@ class GammaRegression : public ObjFunction {
std::vector<int>& label_correct_h = label_correct_.HostVector();
for (auto const flag : label_correct_h) {
if (flag == 0) {
LOG(FATAL) << "GammaRegression: label must be positive.";
LOG(FATAL) << "GammaRegression: label must be non-negative.";
}
}
}
Expand Down

0 comments on commit becbbd5

Please sign in to comment.