[R-package] prefer params to keyword argument in lgb.train() #5007

Merged · 5 commits · Feb 18, 2022
Changes from 1 commit
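
Before the diffs, here is a minimal sketch of the precedence rule this PR establishes. The helper resolve_objective() is a hypothetical stand-in for illustration only, not LightGBM's actual lgb.check.wrapper_param() implementation: a value supplied in params wins, and the keyword argument is used only as a fallback.

# hypothetical helper illustrating the precedence rule: params first, keyword second
resolve_objective <- function(params, obj = NULL) {
    if (!is.null(params[["objective"]])) {
        return(params)               # value in params takes precedence
    }
    params[["objective"]] <- obj     # otherwise fall back to the keyword argument
    return(params)
}

params <- resolve_objective(params = list(objective = "regression_l2"), obj = "regression_l1")
print(params[["objective"]])  # "regression_l2" -- the keyword argument is ignored
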
R-package/R/lgb.cv.R (4 changes: 2 additions & 2 deletions)

@@ -127,7 +127,7 @@ lgb.cv <- function(params = list()
    params <- lgb.check.wrapper_param(
        main_param_name = "objective"
        , params = params
-       , alternative_kwarg_value = NULL
+       , alternative_kwarg_value = obj
    )
    params <- lgb.check.wrapper_param(
        main_param_name = "early_stopping_round"
@@ -137,7 +137,7 @@ lgb.cv <- function(params = list()
    early_stopping_rounds <- params[["early_stopping_round"]]

    # extract any function objects passed for objective or metric
-   params <- lgb.check.obj(params = params, obj = obj)
+   params <- lgb.check.obj(params = params)
    fobj <- NULL
    if (is.function(params$objective)) {
        fobj <- params$objective
R-package/R/lgb.train.R (4 changes: 2 additions & 2 deletions)

@@ -95,7 +95,7 @@ lgb.train <- function(params = list(),
    params <- lgb.check.wrapper_param(
        main_param_name = "objective"
        , params = params
-       , alternative_kwarg_value = NULL
+       , alternative_kwarg_value = obj
    )
    params <- lgb.check.wrapper_param(
        main_param_name = "early_stopping_round"
@@ -105,7 +105,7 @@ lgb.train <- function(params = list(),
    early_stopping_rounds <- params[["early_stopping_round"]]

    # extract any function objects passed for objective or metric
-   params <- lgb.check.obj(params = params, obj = obj)
+   params <- lgb.check.obj(params = params)
    fobj <- NULL
    if (is.function(params$objective)) {
        fobj <- params$objective
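
A usage sketch of the resulting behaviour (the synthetic dataset and seed below are made up for illustration; the assertion mirrors the new test added to test_basic.R further down): if both params$objective and the obj keyword are supplied, lgb.train() now keeps the value from params.

library(lightgbm)

# small synthetic regression dataset, just enough rows to train a few rounds
set.seed(708L)
dtrain <- lgb.Dataset(
    data = matrix(rnorm(300L), ncol = 3L)
    , label = rnorm(100L)
)
bst <- lgb.train(
    params = list(objective = "regression_l2", verbosity = -1L)
    , data = dtrain
    , nrounds = 5L
    , obj = "regression_l1"
)
print(bst$params$objective)  # "regression_l2" -- params wins over the obj keyword
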
R-package/R/utils.R (13 changes: 3 additions & 10 deletions)

@@ -117,7 +117,7 @@ lgb.check_interaction_constraints <- function(interaction_constraints, column_na

}

-lgb.check.obj <- function(params, obj) {
+lgb.check.obj <- function(params) {

    # List known objectives in a vector
    OBJECTIVES <- c(
@@ -158,25 +158,18 @@ lgb.check.obj <- function(params) {
        , "xendcg_mart"
    )

-   # Check whether the objective is empty or not, and take it from params if needed
-   if (!is.null(obj)) {
-       params$objective <- obj
+   if (is.null(params$objective)) {
+       stop("lgb.check.obj: objective should be a character or a function")
    }

    # Check whether the objective is a character
    if (is.character(params$objective)) {

        # If the objective is a character, check if it is a known objective
        if (!(params$objective %in% OBJECTIVES)) {

            stop("lgb.check.obj: objective name error should be one of (", paste0(OBJECTIVES, collapse = ", "), ")")
StrikerRUS (Collaborator) commented on Feb 16, 2022:

Maybe you remember why this line is needed? Or, more generally, why is the lgb.check.obj() function needed at all?
It looks like it duplicates a similar check on the C++ side. For example, there is no such logic in the Python package, yet users can still only pass registered objective names. Please consider the following example:

import numpy as np
import lightgbm as lgb

X = np.random.random((30, 5))
y = np.random.random((30))

lgb.LGBMRegressor(objective='no_such_objective').fit(X, y)

and the output:

---------------------------------------------------------------------------
LightGBMError                             Traceback (most recent call last)
<ipython-input-7-adb0d6a9439c> in <module>
----> 1 lgb.LGBMRegressor(objective='no_such_objective').fit(X, y)

D:\Miniconda3\lib\site-packages\lightgbm\sklearn.py in fit(self, X, y, sample_weight, init_score, eval_set, eval_names, eval_sample_weight, eval_init_score, eval_metric, early_stopping_rounds, verbose, feature_name, categorical_feature, callbacks, init_model)
    893             callbacks=None, init_model=None):
    894         """Docstring is inherited from the LGBMModel."""
--> 895         super().fit(X, y, sample_weight=sample_weight, init_score=init_score,
    896                     eval_set=eval_set, eval_names=eval_names, eval_sample_weight=eval_sample_weight,
    897                     eval_init_score=eval_init_score, eval_metric=eval_metric,

D:\Miniconda3\lib\site-packages\lightgbm\sklearn.py in fit(self, X, y, sample_weight, init_score, group, eval_set, eval_names, eval_sample_weight, eval_class_weight, eval_init_score, eval_group, eval_metric, early_stopping_rounds, verbose, feature_name, categorical_feature, callbacks, init_model)
    746         callbacks.append(record_evaluation(evals_result))
    747 
--> 748         self._Booster = train(
    749             params=params,
    750             train_set=train_set,

D:\Miniconda3\lib\site-packages\lightgbm\engine.py in train(params, train_set, num_boost_round, valid_sets, valid_names, fobj, feval, init_model, feature_name, categorical_feature, early_stopping_rounds, evals_result, verbose_eval, learning_rates, keep_training_booster, callbacks)
    269     # construct booster
    270     try:
--> 271         booster = Booster(params=params, train_set=train_set)
    272         if is_valid_contain_train:
    273             booster.set_train_data_name(train_data_name)

D:\Miniconda3\lib\site-packages\lightgbm\basic.py in __init__(self, params, train_set, model_file, model_str, silent)
   2608             params_str = param_dict_to_str(params)
   2609             self.handle = ctypes.c_void_p()
-> 2610             _safe_call(_LIB.LGBM_BoosterCreate(
   2611                 train_set.handle,
   2612                 c_str(params_str),

D:\Miniconda3\lib\site-packages\lightgbm\basic.py in _safe_call(ret)
    123     """
    124     if ret != 0:
--> 125         raise LightGBMError(_LIB.LGBM_GetLastError().decode('utf-8'))
    126 
    127 

LightGBMError: Unknown objective type name: no_such_objective

jameslamb (Collaborator, PR author) replied:

Looks to me like lgb.check.obj() has been a part of the R package since the very first PR introducing the R package, back in 2017.

https://github.com/microsoft/LightGBM/blame/9259a5318cd03951f2eefe378d9fb8db74457ee3/R-package/R/utils.R#L120

#168

I agree that in the current version of the R package this isn't really necessary any more, and I'd support removing it so that the R package no longer has to maintain a hard-coded list of objective names. For example:

library(lightgbm)

data("EuStockMarkets")
lightgbm:::Booster$new(
    train_set = lightgbm::lgb.Dataset(
        data = EuStockMarkets[, c("SMI", "CAC", "FTSE")]
        , label = EuStockMarkets[, "DAX"]
    )
    , params = list(
        objective = "no_such_objective"
    )
)

This raises the following exception:

[LightGBM] [Fatal] Unknown objective type name: no_such_objective
Error in try({ : Unknown objective type name: no_such_objective
Error in initialize(...) : lgb.Booster: cannot create Booster handle

I'll open a separate PR tonight factoring it out.

StrikerRUS (Collaborator) replied:

Great, thanks!


        }

    } else if (!is.function(params$objective)) {

        stop("lgb.check.obj: objective should be a character or a function")

    }

    return(params)
R-package/tests/testthat/test_basic.R (34 changes: 34 additions & 0 deletions)

@@ -569,6 +569,23 @@ test_that("lgb.cv() respects parameter aliases for objective", {
    expect_length(cv_bst$boosters, nfold)
})

test_that("lgb.cv() prefers objective in params to keyword argument", {
data("EuStockMarkets")
bst <- lgb.train(
StrikerRUS marked this conversation as resolved.
Show resolved Hide resolved
data = lgb.Dataset(
data = EuStockMarkets[, c("SMI", "CAC", "FTSE")]
, label = EuStockMarkets[, "DAX"]
)
, params = list(
objective = "regression_l2"
, verbosity = -1L
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
)
, nrounds = 5L
, obj = "regression_l1"
)
expect_equal(bst$params$objective, "regression_l2")
})

test_that("lgb.cv() respects parameter aliases for metric", {
nrounds <- 3L
nfold <- 4L
@@ -684,6 +701,23 @@ test_that("lgb.train() respects parameter aliases for objective", {
    expect_equal(bst$params[["objective"]], "binary")
})

test_that("lgb.train() prefers objective in params to keyword argument", {
data("EuStockMarkets")
bst <- lgb.train(
data = lgb.Dataset(
data = EuStockMarkets[, c("SMI", "CAC", "FTSE")]
, label = EuStockMarkets[, "DAX"]
)
, params = list(
objective = "regression_l2"
, verbosity = -1L
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
)
, nrounds = 5L
, obj = "regression_l1"
)
expect_equal(bst$params$objective, "regression_l2")
StrikerRUS marked this conversation as resolved.
Show resolved Hide resolved
})

test_that("lgb.train() respects parameter aliases for metric", {
nrounds <- 3L
dtrain <- lgb.Dataset(