From becbbd56ebdc17ddc7a794211b340f888bf1a7bc Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 29 Sep 2021 15:25:50 +0800 Subject: [PATCH] Fix gamma neg log likelihood. Restore non-negative requirement. --- src/metric/elementwise_metric.cu | 3 +-- src/objective/regression_obj.cu | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 25492bf2c5f8..9ff0cf1419b2 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -309,8 +309,7 @@ struct EvalGammaNLogLik { float constexpr kPsi = 1.0; bst_float theta = -1. / py; bst_float a = kPsi; - // b = -std::log(-theta); - float b = 1.0f; + float b = -std::log(-theta); // c = 1. / kPsi * std::log(y/kPsi) - std::log(y) - common::LogGamma(1. / kPsi); // = 1.0f * std::log(y) - std::log(y) - 0 = 0 float c = 0; diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index ccb3a723d32e..43c27ab064eb 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -421,7 +421,7 @@ class GammaRegression : public ObjFunction { bst_float p = _preds[_idx]; bst_float w = is_null_weight ? 1.0f : _weights[_idx]; bst_float y = _labels[_idx]; - if (y <= 0.0f) { + if (y < 0.0f) { _label_correct[0] = 0; } _out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w); @@ -433,7 +433,7 @@ class GammaRegression : public ObjFunction { std::vector& label_correct_h = label_correct_.HostVector(); for (auto const flag : label_correct_h) { if (flag == 0) { - LOG(FATAL) << "GammaRegression: label must be positive."; + LOG(FATAL) << "GammaRegression: label must be non-negative."; } } }