From 4b1b412452218c5be5ac0f238454ec9309036798 Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Fri, 21 May 2021 14:28:10 +0200 Subject: [PATCH] fix calculation of weighted gamma loss (fixes #4174) (#4283) * fixed weighted gamma obj * added unit tests * fixing linter errors * another linter * set seed * fix linter (integer seed) --- R-package/tests/testthat/test_weighted_loss.R | 67 +++++++++++++++++++ src/objective/regression_objective.hpp | 2 +- 2 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 R-package/tests/testthat/test_weighted_loss.R diff --git a/R-package/tests/testthat/test_weighted_loss.R b/R-package/tests/testthat/test_weighted_loss.R new file mode 100644 index 000000000000..752f80ce27e6 --- /dev/null +++ b/R-package/tests/testthat/test_weighted_loss.R @@ -0,0 +1,67 @@ +context("Case weights are respected") + +test_that("Gamma regression reacts on 'weight'", { + n <- 100L + set.seed(87L) + X <- matrix(runif(2L * n), ncol = 2L) + y <- X[, 1L] + X[, 2L] + runif(n) + X_pred <- X[1L:5L, ] + + params <- list(objective = "gamma") + + # Unweighted + dtrain <- lgb.Dataset(X, label = y) + bst <- lgb.train( + params = params + , data = dtrain + , nrounds = 4L + , verbose = 0L + ) + pred_unweighted <- predict(bst, X_pred) + + # Constant weight 1 + dtrain <- lgb.Dataset( + X + , label = y + , weight = rep(1.0, n) + ) + bst <- lgb.train( + params = params + , data = dtrain + , nrounds = 4L + , verbose = 0L + ) + pred_weighted_1 <- predict(bst, X_pred) + + # Constant weight 2 + dtrain <- lgb.Dataset( + X + , label = y + , weight = rep(2.0, n) + ) + bst <- lgb.train( + params = params + , data = dtrain + , nrounds = 4L + , verbose = 0L + ) + pred_weighted_2 <- predict(bst, X_pred) + + # Non-constant weights + dtrain <- lgb.Dataset( + X + , label = y + , weight = seq(0.0, 1.0, length.out = n) + ) + bst <- lgb.train( + params = params + , data = dtrain + , nrounds = 4L + , verbose = 0L + ) + pred_weighted <- predict(bst, X_pred) + + 
expect_equal(pred_unweighted, pred_weighted_1) + expect_equal(pred_weighted_1, pred_weighted_2) + expect_false(all(pred_unweighted == pred_weighted)) +}) diff --git a/src/objective/regression_objective.hpp b/src/objective/regression_objective.hpp index 753224bd5603..e711da012066 100644 --- a/src/objective/regression_objective.hpp +++ b/src/objective/regression_objective.hpp @@ -695,7 +695,7 @@ class RegressionGammaLoss : public RegressionPoissonLoss { } else { #pragma omp parallel for schedule(static) for (data_size_t i = 0; i < num_data_; ++i) { - gradients[i] = static_cast<score_t>(1.0 - label_[i] * std::exp(-score[i]) * weights_[i]); + gradients[i] = static_cast<score_t>((1.0 - label_[i] * std::exp(-score[i])) * weights_[i]); hessians[i] = static_cast<score_t>(label_[i] * std::exp(-score[i]) * weights_[i]); } }