diff --git a/R-package/R/lgb.cv.R b/R-package/R/lgb.cv.R index 7dbd75da9aef..fdc93c2f2913 100644 --- a/R-package/R/lgb.cv.R +++ b/R-package/R/lgb.cv.R @@ -137,7 +137,6 @@ lgb.cv <- function(params = list() early_stopping_rounds <- params[["early_stopping_round"]] # extract any function objects passed for objective or metric - params <- lgb.check.obj(params = params) fobj <- NULL if (is.function(params$objective)) { fobj <- params$objective diff --git a/R-package/R/lgb.train.R b/R-package/R/lgb.train.R index ce73a851405f..89b12d0bb03d 100644 --- a/R-package/R/lgb.train.R +++ b/R-package/R/lgb.train.R @@ -105,7 +105,6 @@ lgb.train <- function(params = list(), early_stopping_rounds <- params[["early_stopping_round"]] # extract any function objects passed for objective or metric - params <- lgb.check.obj(params = params) fobj <- NULL if (is.function(params$objective)) { fobj <- params$objective diff --git a/R-package/R/utils.R b/R-package/R/utils.R index 86b3624f482f..c89bfe9fb0b2 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -117,64 +117,6 @@ lgb.check_interaction_constraints <- function(interaction_constraints, column_na } -lgb.check.obj <- function(params) { - - # List known objectives in a vector - OBJECTIVES <- c( - "regression" - , "regression_l1" - , "regression_l2" - , "mean_squared_error" - , "mse" - , "l2_root" - , "root_mean_squared_error" - , "rmse" - , "mean_absolute_error" - , "mae" - , "quantile" - , "huber" - , "fair" - , "poisson" - , "binary" - , "lambdarank" - , "multiclass" - , "softmax" - , "multiclassova" - , "multiclass_ova" - , "ova" - , "ovr" - , "xentropy" - , "cross_entropy" - , "xentlambda" - , "cross_entropy_lambda" - , "mean_absolute_percentage_error" - , "mape" - , "gamma" - , "tweedie" - , "rank_xendcg" - , "xendcg" - , "xe_ndcg" - , "xe_ndcg_mart" - , "xendcg_mart" - ) - - if (is.null(params$objective)) { - stop("lgb.check.obj: objective should be a character or a function") - } - - if (is.character(params$objective)) { - - if (!(params$objective %in% OBJECTIVES)) { - - stop("lgb.check.obj: objective name error should be one of (", paste0(OBJECTIVES, collapse = ", "), ")") - - } - - } - - return(params) - -} # [description] # Take any character values from eval and store them in params$metric. diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 9188e0a4e719..ab5accab6144 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -520,6 +520,22 @@ test_that("lgb.cv() respects showsd argument", { expect_identical(evals_no_showsd[["eval_err"]], list()) }) +test_that("lgb.cv() raises an informative error for unrecognized objectives", { + dtrain <- lgb.Dataset( + data = train$data + , label = train$label + ) + expect_error({ + bst <- lgb.cv( + data = dtrain + , params = list( + objective_type = "not_a_real_objective" + , verbosity = VERBOSITY + ) + ) + }, regexp = "Unknown objective type name: not_a_real_objective") +}) + test_that("lgb.cv() respects parameter aliases for objective", { nrounds <- 3L nfold <- 4L @@ -663,6 +679,22 @@ test_that("lgb.train() works as expected with multiple eval metrics", { ) }) +test_that("lgb.train() raises an informative error for unrecognized objectives", { + dtrain <- lgb.Dataset( + data = train$data + , label = train$label + ) + expect_error({ + bst <- lgb.train( + data = dtrain + , params = list( + objective_type = "not_a_real_objective" + , verbosity = VERBOSITY + ) + ) + }, regexp = "Unknown objective type name: not_a_real_objective") +}) + test_that("lgb.train() respects parameter aliases for objective", { nrounds <- 3L dtrain <- lgb.Dataset(