Skip to content

Commit

Permalink
[R-package] added tests on LGBM_BoosterResetTrainingData_R (microsoft…
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored and ChipKerchner committed Jun 10, 2020
1 parent 7b66eef commit 23da369
Showing 1 changed file with 63 additions and 0 deletions.
63 changes: 63 additions & 0 deletions R-package/tests/testthat/test_lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,66 @@ test_that("Booster$rollback_one_iter() should work as expected", {
logloss <- bst$eval_train()[[1L]][["value"]]
expect_equal(logloss, 0.027915146)
})

test_that("Booster$update() passing a train_set works as expected", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
nrounds <- 2L

# train with 2 rounds and then update
bst <- lightgbm(
data = as.matrix(agaricus.train$data)
, label = agaricus.train$label
, num_leaves = 4L
, learning_rate = 1.0
, nrounds = nrounds
, objective = "binary"
)
expect_true(lgb.is.Booster(bst))
expect_equal(bst$current_iter(), nrounds)
bst$update(
train_set = Dataset$new(
data = agaricus.train$data
, label = agaricus.train$label
)
)
expect_true(lgb.is.Booster(bst))
expect_equal(bst$current_iter(), nrounds + 1L)

# train with 3 rounds directlry
bst2 <- lightgbm(
data = as.matrix(agaricus.train$data)
, label = agaricus.train$label
, num_leaves = 4L
, learning_rate = 1.0
, nrounds = nrounds + 1L
, objective = "binary"
)
expect_true(lgb.is.Booster(bst2))
expect_equal(bst2$current_iter(), nrounds + 1L)

# model with 2 rounds + 1 update should be identical to 3 rounds
expect_equal(bst2$eval_train()[[1L]][["value"]], 0.04806585)
expect_equal(bst$eval_train()[[1L]][["value"]], bst2$eval_train()[[1L]][["value"]])
})

test_that("Booster$update() throws an informative error if you provide a non-Dataset to update()", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
nrounds <- 2L

# train with 2 rounds and then update
bst <- lightgbm(
data = as.matrix(agaricus.train$data)
, label = agaricus.train$label
, num_leaves = 4L
, learning_rate = 1.0
, nrounds = nrounds
, objective = "binary"
)
expect_error({
bst$update(
train_set = data.frame(x = rnorm(10L))
)
}, regexp = "lgb.Booster.update: Only can use lgb.Dataset", fixed = TRUE)
})

0 comments on commit 23da369

Please sign in to comment.