From e7ab89bce0f89ed209751cb008f32297a1fe9ba2 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Mon, 14 Feb 2022 20:29:20 -0600 Subject: [PATCH 1/3] [R-package] prefer params to keyword argument in lgb.train() --- R-package/R/lgb.cv.R | 4 ++-- R-package/R/lgb.train.R | 4 ++-- R-package/R/utils.R | 13 +++------- R-package/tests/testthat/test_basic.R | 34 +++++++++++++++++++++++++++ 4 files changed, 41 insertions(+), 14 deletions(-) diff --git a/R-package/R/lgb.cv.R b/R-package/R/lgb.cv.R index abe72220c9e4..7dbd75da9aef 100644 --- a/R-package/R/lgb.cv.R +++ b/R-package/R/lgb.cv.R @@ -127,7 +127,7 @@ lgb.cv <- function(params = list() params <- lgb.check.wrapper_param( main_param_name = "objective" , params = params - , alternative_kwarg_value = NULL + , alternative_kwarg_value = obj ) params <- lgb.check.wrapper_param( main_param_name = "early_stopping_round" @@ -137,7 +137,7 @@ lgb.cv <- function(params = list() early_stopping_rounds <- params[["early_stopping_round"]] # extract any function objects passed for objective or metric - params <- lgb.check.obj(params = params, obj = obj) + params <- lgb.check.obj(params = params) fobj <- NULL if (is.function(params$objective)) { fobj <- params$objective diff --git a/R-package/R/lgb.train.R b/R-package/R/lgb.train.R index eebf66ba405f..ce73a851405f 100644 --- a/R-package/R/lgb.train.R +++ b/R-package/R/lgb.train.R @@ -95,7 +95,7 @@ lgb.train <- function(params = list(), params <- lgb.check.wrapper_param( main_param_name = "objective" , params = params - , alternative_kwarg_value = NULL + , alternative_kwarg_value = obj ) params <- lgb.check.wrapper_param( main_param_name = "early_stopping_round" @@ -105,7 +105,7 @@ lgb.train <- function(params = list(), early_stopping_rounds <- params[["early_stopping_round"]] # extract any function objects passed for objective or metric - params <- lgb.check.obj(params = params, obj = obj) + params <- lgb.check.obj(params = params) fobj <- NULL if (is.function(params$objective)) { fobj <- params$objective diff --git a/R-package/R/utils.R b/R-package/R/utils.R index 3cdab4dcfd08..86b3624f482f 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -117,7 +117,7 @@ lgb.check_interaction_constraints <- function(interaction_constraints, column_na } -lgb.check.obj <- function(params, obj) { +lgb.check.obj <- function(params) { # List known objectives in a vector OBJECTIVES <- c( @@ -158,25 +158,18 @@ lgb.check.obj <- function(params, obj) { , "xendcg_mart" ) - # Check whether the objective is empty or not, and take it from params if needed - if (!is.null(obj)) { - params$objective <- obj + if (is.null(params$objective)) { + stop("lgb.check.obj: objective should be a character or a function") } - # Check whether the objective is a character if (is.character(params$objective)) { - # If the objective is a character, check if it is a known objective if (!(params$objective %in% OBJECTIVES)) { stop("lgb.check.obj: objective name error should be one of (", paste0(OBJECTIVES, collapse = ", "), ")") } - } else if (!is.function(params$objective)) { - - stop("lgb.check.obj: objective should be a character or a function") - } return(params) diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 8b6f5f6ceb44..a05c5c516d5a 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -569,6 +569,23 @@ test_that("lgb.cv() respects parameter aliases for objective", { expect_length(cv_bst$boosters, nfold) }) +test_that("lgb.cv() prefers objective in params to keyword argument", { + data("EuStockMarkets") + bst <- lgb.train( + data = lgb.Dataset( + data = EuStockMarkets[, c("SMI", "CAC", "FTSE")] + , label = EuStockMarkets[, "DAX"] + ) + , params = list( + objective = "regression_l2" + , verbosity = -1L + ) + , nrounds = 5L + , obj = "regression_l1" + ) + expect_equal(bst$params$objective, "regression_l2") +}) + test_that("lgb.cv() respects parameter aliases for metric", { nrounds <- 3L nfold <- 4L @@ -684,6 +701,23 @@ test_that("lgb.train() respects parameter aliases for objective", { expect_equal(bst$params[["objective"]], "binary") }) +test_that("lgb.train() prefers objective in params to keyword argument", { + data("EuStockMarkets") + bst <- lgb.train( + data = lgb.Dataset( + data = EuStockMarkets[, c("SMI", "CAC", "FTSE")] + , label = EuStockMarkets[, "DAX"] + ) + , params = list( + objective = "regression_l2" + , verbosity = -1L + ) + , nrounds = 5L + , obj = "regression_l1" + ) + expect_equal(bst$params$objective, "regression_l2") +}) + test_that("lgb.train() respects parameter aliases for metric", { nrounds <- 3L dtrain <- lgb.Dataset( From 297a0a561c37f4facbcf7c77b0328c4fc096d40c Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 15 Feb 2022 19:50:51 -0600 Subject: [PATCH 2/3] make test stricter --- R-package/tests/testthat/test_basic.R | 37 ++++++++++++++++++++------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index a05c5c516d5a..13dd7ce23920 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -571,19 +571,30 @@ test_that("lgb.cv() respects parameter aliases for objective", { test_that("lgb.cv() prefers objective in params to keyword argument", { data("EuStockMarkets") - bst <- lgb.train( + cv_bst <- lgb.cv( data = lgb.Dataset( data = EuStockMarkets[, c("SMI", "CAC", "FTSE")] , label = EuStockMarkets[, "DAX"] ) , params = list( - objective = "regression_l2" + application = "regression_l1" , verbosity = -1L ) , nrounds = 5L - , obj = "regression_l1" - ) - expect_equal(bst$params$objective, "regression_l2") + , obj = "regression_l2" + ) + for (bst_list in cv_bst$boosters) { + bst <- bst_list[["booster"]] + expect_equal(bst$params$objective, "regression_l1") + # NOTE: using save_model_to_string() since that is the simplest public API in the R package + # allowing access to the "objective" attribute of the Booster object on the C++ side + model_txt_lines <- strsplit( + x = bst$save_model_to_string() + , split = "\n" + )[[1L]] + expect_true(any(model_txt_lines == "objective=regression_l1")) + expect_false(any(model_txt_lines == "objective=regression_l2")) + } }) test_that("lgb.cv() respects parameter aliases for metric", { @@ -709,13 +720,21 @@ test_that("lgb.train() prefers objective in params to keyword argument", { , label = EuStockMarkets[, "DAX"] ) , params = list( - objective = "regression_l2" + loss = "regression_l1" , verbosity = -1L ) , nrounds = 5L - , obj = "regression_l1" - ) - expect_equal(bst$params$objective, "regression_l2") + , obj = "regression_l2" + ) + expect_equal(bst$params$objective, "regression_l1") + # NOTE: using save_model_to_string() since that is the simplest public API in the R package + # allowing access to the "objective" attribute of the Booster object on the C++ side + model_txt_lines <- strsplit( + x = bst$save_model_to_string() + , split = "\n" + )[[1L]] + expect_true(any(model_txt_lines == "objective=regression_l1")) + expect_false(any(model_txt_lines == "objective=regression_l2")) }) test_that("lgb.train() respects parameter aliases for metric", { From 0ee178a03c2375b12a0e0307432bfd8a4ff9aa7d Mon Sep 17 00:00:00 2001 From: James Lamb Date: Wed, 16 Feb 2022 20:27:56 -0600 Subject: [PATCH 3/3] Apply suggestions from code review Co-authored-by: Nikita Titov --- R-package/tests/testthat/test_basic.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 13dd7ce23920..0dd0098243d6 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -578,7 +578,7 @@ test_that("lgb.cv() prefers objective in params to keyword argument", { ) , params = list( application = "regression_l1" - , verbosity = -1L + , verbosity = VERBOSITY ) , nrounds = 5L , obj = "regression_l2" @@ -721,7 +721,7 @@ test_that("lgb.train() prefers objective in params to keyword argument", { ) , params = list( loss = "regression_l1" - , verbosity = -1L + , verbosity = VERBOSITY ) , nrounds = 5L , obj = "regression_l2"