Skip to content

Commit

Permalink
[R-package] remove internal function lgb.check.obj() (#5021)
Browse files Browse the repository at this point in the history
* factor out lgb.check.obj()

* remove lgb.check.obj()

* add test on lgb.cv()
  • Loading branch information
jameslamb authored Feb 23, 2022
1 parent 406bc7d commit a1fbe84
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 60 deletions.
1 change: 0 additions & 1 deletion R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ lgb.cv <- function(params = list()
early_stopping_rounds <- params[["early_stopping_round"]]

# extract any function objects passed for objective or metric
params <- lgb.check.obj(params = params)
fobj <- NULL
if (is.function(params$objective)) {
fobj <- params$objective
Expand Down
1 change: 0 additions & 1 deletion R-package/R/lgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ lgb.train <- function(params = list(),
early_stopping_rounds <- params[["early_stopping_round"]]

# extract any function objects passed for objective or metric
params <- lgb.check.obj(params = params)
fobj <- NULL
if (is.function(params$objective)) {
fobj <- params$objective
Expand Down
58 changes: 0 additions & 58 deletions R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -117,64 +117,6 @@ lgb.check_interaction_constraints <- function(interaction_constraints, column_na

}

lgb.check.obj <- function(params) {

# List known objectives in a vector
OBJECTIVES <- c(
"regression"
, "regression_l1"
, "regression_l2"
, "mean_squared_error"
, "mse"
, "l2_root"
, "root_mean_squared_error"
, "rmse"
, "mean_absolute_error"
, "mae"
, "quantile"
, "huber"
, "fair"
, "poisson"
, "binary"
, "lambdarank"
, "multiclass"
, "softmax"
, "multiclassova"
, "multiclass_ova"
, "ova"
, "ovr"
, "xentropy"
, "cross_entropy"
, "xentlambda"
, "cross_entropy_lambda"
, "mean_absolute_percentage_error"
, "mape"
, "gamma"
, "tweedie"
, "rank_xendcg"
, "xendcg"
, "xe_ndcg"
, "xe_ndcg_mart"
, "xendcg_mart"
)

if (is.null(params$objective)) {
stop("lgb.check.obj: objective should be a character or a function")
}

if (is.character(params$objective)) {

if (!(params$objective %in% OBJECTIVES)) {

stop("lgb.check.obj: objective name error should be one of (", paste0(OBJECTIVES, collapse = ", "), ")")

}

}

return(params)

}

# [description]
# Take any character values from eval and store them in params$metric.
Expand Down
32 changes: 32 additions & 0 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,22 @@ test_that("lgb.cv() respects showsd argument", {
expect_identical(evals_no_showsd[["eval_err"]], list())
})

test_that("lgb.cv() raises an informative error for unrecognized objectives", {
dtrain <- lgb.Dataset(
data = train$data
, label = train$label
)
expect_error({
bst <- lgb.cv(
data = dtrain
, params = list(
objective_type = "not_a_real_objective"
, verbosity = VERBOSITY
)
)
}, regexp = "Unknown objective type name: not_a_real_objective")
})

test_that("lgb.cv() respects parameter aliases for objective", {
nrounds <- 3L
nfold <- 4L
Expand Down Expand Up @@ -663,6 +679,22 @@ test_that("lgb.train() works as expected with multiple eval metrics", {
)
})

test_that("lgb.train() raises an informative error for unrecognized objectives", {
dtrain <- lgb.Dataset(
data = train$data
, label = train$label
)
expect_error({
bst <- lgb.train(
data = dtrain
, params = list(
objective_type = "not_a_real_objective"
, verbosity = VERBOSITY
)
)
}, regexp = "Unknown objective type name: not_a_real_objective")
})

test_that("lgb.train() respects parameter aliases for objective", {
nrounds <- 3L
dtrain <- lgb.Dataset(
Expand Down

0 comments on commit a1fbe84

Please sign in to comment.