Skip to content

Commit

Permalink
fix calculation of weighted gamma loss (fixes #4174) (#4283)
Browse files Browse the repository at this point in the history
* fixed weighted gamma obj

* added unit tests

* fixing linter errors

* another linter

* set seed

* fix linter (integer seed)
  • Loading branch information
mayer79 authored May 21, 2021
1 parent 237ac29 commit 4b1b412
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
67 changes: 67 additions & 0 deletions R-package/tests/testthat/test_weighted_loss.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
context("Case weights are respected")

test_that("Gamma regression reacts on 'weight'", {
  n <- 100L
  set.seed(87L)
  X <- matrix(runif(2L * n), ncol = 2L)
  y <- X[, 1L] + X[, 2L] + runif(n)
  X_pred <- X[1L:5L, ]

  params <- list(objective = "gamma")

  # Train a 4-round gamma model on (X, y) with the given case weights
  # (NULL = no weight column) and return predictions for X_pred.
  fit_and_predict <- function(w) {
    dtrain <- lgb.Dataset(
      X
      , label = y
      , weight = w
    )
    bst <- lgb.train(
      params = params
      , data = dtrain
      , nrounds = 4L
      , verbose = 0L
    )
    predict(bst, X_pred)
  }

  pred_unweighted <- fit_and_predict(NULL)
  pred_weighted_1 <- fit_and_predict(rep(1.0, n))
  pred_weighted_2 <- fit_and_predict(rep(2.0, n))
  pred_weighted <- fit_and_predict(seq(0.0, 1.0, length.out = n))

  # Any constant weight vector must leave predictions unchanged ...
  expect_equal(pred_unweighted, pred_weighted_1)
  expect_equal(pred_weighted_1, pred_weighted_2)
  # ... while genuinely non-constant weights must change them.
  expect_false(all(pred_unweighted == pred_weighted))
})
2 changes: 1 addition & 1 deletion src/objective/regression_objective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ class RegressionGammaLoss : public RegressionPoissonLoss {
} else {
#pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) {
gradients[i] = static_cast<score_t>(1.0 - label_[i] * std::exp(-score[i]) * weights_[i]);
gradients[i] = static_cast<score_t>((1.0 - label_[i] * std::exp(-score[i])) * weights_[i]);
hessians[i] = static_cast<score_t>(label_[i] * std::exp(-score[i]) * weights_[i]);
}
}
Expand Down

0 comments on commit 4b1b412

Please sign in to comment.