Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] added support for first_metric_only (fixes #2368) #2912

Merged
merged 17 commits into from
Sep 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions R-package/R/aliases.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,17 @@
)
return(c(learning_params, .DATASET_PARAMETERS()))
}

# [description]
# Per https://github.com/microsoft/LightGBM/blob/master/docs/Parameters.rst#metric,
# a few different strings can be used to indicate "no metrics".
# [returns]
# A character vector
.NO_METRIC_STRINGS <- function() {
return(c(
"na"
, "None"
, "null"
, "custom"
))
}
12 changes: 10 additions & 2 deletions R-package/R/callback.R
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ cb.record.evaluation <- function() {

}

cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
cb.early.stop <- function(stopping_rounds, first_metric_only = FALSE, verbose = TRUE) {

# Initialize variables
factor_to_bigger_better <- NULL
Expand Down Expand Up @@ -325,8 +325,16 @@ cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
# Store iteration
cur_iter <- env$iteration

# By default, any metric can trigger early stopping. This can be disabled
# with 'first_metric_only = TRUE'
if (isTRUE(first_metric_only)) {
evals_to_check <- 1L
} else {
evals_to_check <- seq_len(eval_len)
}

# Loop through evaluation
for (i in seq_len(eval_len)) {
for (i in evals_to_check) {

# Store score
score <- env$eval_list[[i]]$value * factor_to_bigger_better[i]
Expand Down
38 changes: 27 additions & 11 deletions R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ CVBooster <- R6::R6Class(
#' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
#' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}
#' @param weight vector of response values. If not NULL, will set to dataset
#' @param obj objective function, can be character or custom objective function. Examples include
#' \code{regression}, \code{regression_l1}, \code{huber},
#' \code{binary}, \code{lambdarank}, \code{multiclass}, \code{multiclass}
#' @param eval evaluation function, can be (list of) character or custom eval function
#' @param record Boolean, TRUE will record iteration message to \code{booster$record_evals}
#' @param showsd \code{boolean}, whether to show standard deviation of cross validation
#' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified
Expand All @@ -52,7 +48,7 @@ CVBooster <- R6::R6Class(
#' the number of real CPU cores, not the number of threads (most
#' CPU using hyper-threading to generate 2 threads per CPU core).}
#' }
#'
#' @inheritSection lgb_shared_params Early Stopping
#' @return a trained model \code{lgb.CVBooster}.
#'
#' @examples
Expand Down Expand Up @@ -114,17 +110,25 @@ lgb.cv <- function(params = list()
params <- lgb.check.obj(params, obj)
params <- lgb.check.eval(params, eval)
fobj <- NULL
feval <- NULL
eval_functions <- list(NULL)

# Check for objective (function or not)
if (is.function(params$objective)) {
fobj <- params$objective
params$objective <- "NONE"
}

# Check for loss (function or not)
# If loss is a single function, store it as a 1-element list
# (for backwards compatibility). If it is a list of functions, store
# all of them
if (is.function(eval)) {
feval <- eval
eval_functions <- list(eval)
}
if (methods::is(eval, "list")) {
eval_functions <- Filter(
f = is.function
, x = eval
)
}

# Init predictor to empty
Expand Down Expand Up @@ -266,6 +270,7 @@ lgb.cv <- function(params = list()
callbacks
, cb.early.stop(
stopping_rounds = early_stopping_rounds
, first_metric_only = isTRUE(params[["first_metric_only"]])
, verbose = verbose
)
)
Expand Down Expand Up @@ -357,7 +362,11 @@ lgb.cv <- function(params = list()
# Update one boosting iteration
msg <- lapply(cv_booster$boosters, function(fd) {
fd$booster$update(fobj = fobj)
fd$booster$eval_valid(feval = feval)
out <- list()
for (eval_function in eval_functions) {
out <- append(out, fd$booster$eval_valid(feval = eval_function))
}
return(out)
})

# Prepare collection of evaluation results
Expand All @@ -384,7 +393,13 @@ lgb.cv <- function(params = list()
# When early stopping is not activated, we compute the best iteration / score ourselves
# based on the first first metric
if (record && is.na(env$best_score)) {
first_metric <- cv_booster$boosters[[1L]][[1L]]$.__enclos_env__$private$eval_names[1L]
# when using a custom eval function, the metric name is returned from the
# function, so figure it out from record_evals
if (!is.null(eval_functions[1L])) {
first_metric <- names(cv_booster$record_evals[["valid"]])[1L]
} else {
first_metric <- cv_booster$.__enclos_env__$private$eval_names[1L]
}
.find_best <- which.min
if (isTRUE(env$eval_list[[1L]]$higher_better[1L])) {
.find_best <- which.max
Expand Down Expand Up @@ -576,7 +591,8 @@ lgb.merge.cv.result <- function(msg, showsd = TRUE) {
msg[[i]][[j]]$value }))
})

# Get evaluation
# Get evaluation. Just taking the first element here to
# get structture (name, higher_bettter, data_name)
ret_eval <- msg[[1L]]

# Go through evaluation length items
Expand Down
47 changes: 34 additions & 13 deletions R-package/R/lgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@
#' @description Logic to train with LightGBM
#' @inheritParams lgb_shared_params
#' @param valids a list of \code{lgb.Dataset} objects, used for validation
#' @param obj objective function, can be character or custom objective function. Examples include
#' \code{regression}, \code{regression_l1}, \code{huber},
#' \code{binary}, \code{lambdarank}, \code{multiclass}, \code{multiclass}
#' @param eval evaluation function, can be (a list of) character or custom eval function
#' @param record Boolean, TRUE will record iteration message to \code{booster$record_evals}
#' @param colnames feature names, if not null, will use this to overwrite the names in dataset
#' @param categorical_feature list of str or int
Expand All @@ -26,6 +22,7 @@
#' the number of real CPU cores, not the number of threads (most
#' CPU using hyper-threading to generate 2 threads per CPU core).}
#' }
#' @inheritSection lgb_shared_params Early Stopping
#' @return a trained booster model \code{lgb.Booster}.
#'
#' @examples
Expand Down Expand Up @@ -90,17 +87,25 @@ lgb.train <- function(params = list(),
params <- lgb.check.obj(params, obj)
params <- lgb.check.eval(params, eval)
fobj <- NULL
feval <- NULL
eval_functions <- list(NULL)

# Check for objective (function or not)
if (is.function(params$objective)) {
fobj <- params$objective
params$objective <- "NONE"
}

# Check for loss (function or not)
# If loss is a single function, store it as a 1-element list
# (for backwards compatibility). If it is a list of functions, store
# all of them
if (is.function(eval)) {
feval <- eval
eval_functions <- list(eval)
}
if (methods::is(eval, "list")) {
eval_functions <- Filter(
f = is.function
, x = eval
)
}

# Init predictor to empty
Expand Down Expand Up @@ -235,6 +240,7 @@ lgb.train <- function(params = list(),
callbacks
, cb.early.stop(
stopping_rounds = early_stopping_rounds
, first_metric_only = isTRUE(params[["first_metric_only"]])
, verbose = verbose
)
)
Expand Down Expand Up @@ -280,13 +286,28 @@ lgb.train <- function(params = list(),
# Collection: Has validation dataset?
if (length(valids) > 0L) {

# Validation has training dataset?
if (valid_contain_train) {
eval_list <- append(eval_list, booster$eval_train(feval = feval))
# Get evaluation results with passed-in functions
for (eval_function in eval_functions) {

# Validation has training dataset?
if (valid_contain_train) {
eval_list <- append(eval_list, booster$eval_train(feval = eval_function))
}

eval_list <- append(eval_list, booster$eval_valid(feval = eval_function))
}

# Calling booster$eval_valid() will get
# evaluation results with the metrics in params$metric by calling LGBM_BoosterGetEval_R",
# so need to be sure that gets called, which it wouldn't be above if no functions
# were passed in
if (length(eval_functions) == 0L) {
if (valid_contain_train) {
eval_list <- append(eval_list, booster$eval_train(feval = eval_function))
}
eval_list <- append(eval_list, booster$eval_valid(feval = eval_function))
}

# Has no validation dataset
eval_list <- append(eval_list, booster$eval_valid(feval = feval))
}

# Write evaluation result in environment
Expand All @@ -312,7 +333,7 @@ lgb.train <- function(params = list(),

# when using a custom eval function, the metric name is returned from the
# function, so figure it out from record_evals
if (!is.null(feval)) {
if (!is.null(eval_functions[1L])) {
first_metric <- names(booster$record_evals[[first_valid_name]])[1L]
} else {
first_metric <- booster$.__enclos_env__$private$eval_names[1L]
Expand Down
51 changes: 51 additions & 0 deletions R-package/R/lightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,61 @@
#' and one metric. If there's more than one, will check all of them
#' except the training data. Returns the model with (best_iter + early_stopping_rounds).
#' If early stopping occurs, the model will have 'best_iter' field.
#' @param eval evaluation function(s). This can be a character vector, function, or list with a mixture of
#' strings and functions.
#'
#' \itemize{
#' \item{\bold{a. character vector}:
#' If you provide a character vector to this argument, it should contain strings with valid
#' evaluation metrics.
#' See \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#metric}{
#' The "metric" section of the documentation}
#' for a list of valid metrics.
#' }
#' \item{\bold{b. function}:
#' You can provide a custom evaluation function. This
#' should accept the keyword arguments \code{preds} and \code{dtrain} and should return a named
#' list with three elements:
#' \itemize{
#' \item{\code{name}: A string with the name of the metric, used for printing
#' and storing results.
#' }
#' \item{\code{value}: A single number indicating the value of the metric for the
#' given predictions and true values
#' }
#' \item{
#' \code{higher_better}: A boolean indicating whether higher values indicate a better fit.
#' For example, this would be \code{FALSE} for metrics like MAE or RMSE.
#' }
#' }
#' }
#' \item{\bold{c. list}:
#' If a list is given, it should only contain character vectors and functions.
#' These should follow the requirements from the descriptions above.
#' }
#' }
#' @param eval_freq evaluation output frequency, only effect when verbose > 0
#' @param init_model path of model file of \code{lgb.Booster} object, will continue training from this model
#' @param nrounds number of training rounds
#' @param obj objective function, can be character or custom objective function. Examples include
#' \code{regression}, \code{regression_l1}, \code{huber},
#' \code{binary}, \code{lambdarank}, \code{multiclass}, \code{multiclass}
#' @param params List of parameters
#' @param verbose verbosity for output, if <= 0, also will disable the print of evaluation during training
#' @section Early Stopping:
#'
#' "early stopping" refers to stopping the training process if the model's performance on a given
#' validation set does not improve for several consecutive iterations.
#'
#' If multiple arguments are given to \code{eval}, their order will be preserved. If you enable
#' early stopping by setting \code{early_stopping_rounds} in \code{params}, by default all
#' metrics will be considered for early stopping.
#'
#' If you want to only consider the first metric for early stopping, pass
#' \code{first_metric_only = TRUE} in \code{params}. Note that if you also specify \code{metric}
#' in \code{params}, that metric will be considered the "first" one. If you omit \code{metric},
#' a default metric will be used based on your choice for the parameter \code{obj} (keyword argument)
#' or \code{objective} (passed into \code{params}).
#' @keywords internal
NULL

Expand Down Expand Up @@ -47,6 +97,7 @@ NULL
#' the number of real CPU cores, not the number of threads (most
#' CPU using hyper-threading to generate 2 threads per CPU core).}
#' }
#' @inheritSection lgb_shared_params Early Stopping
#' @export
lightgbm <- function(data,
label = NULL,
Expand Down
33 changes: 25 additions & 8 deletions R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,10 @@ lgb.check.obj <- function(params, obj) {
}

# [description]
# make sure that "metric" is populated on params,
# and add any eval values to it
# [return]
# params, where "metric" is a list
# Take any character values from eval and store them in params$metric.
# This has to account for the fact that `eval` could be a character vector,
# a function, a list of functions, or a list with a mix of strings and
# functions
lgb.check.eval <- function(params, eval) {

if (is.null(params$metric)) {
Expand All @@ -330,13 +330,30 @@ lgb.check.eval <- function(params, eval) {
params$metric <- as.list(params$metric)
}

if (is.character(eval)) {
params$metric <- append(params$metric, eval)
# if 'eval' is a character vector or list, find the character
# elements and add them to 'metric'
if (!is.function(eval)) {
for (i in seq_along(eval)) {
element <- eval[[i]]
if (is.character(element)) {
params$metric <- append(params$metric, element)
}
}
}

if (identical(class(eval), "list")) {
params$metric <- append(params$metric, unlist(eval))
# If more than one character metric was given, then "None" should
# not be included
if (length(params$metric) > 1L) {
params$metric <- Filter(
f = function(metric) {
!(metric %in% .NO_METRIC_STRINGS())
}
, x = params$metric
)
}

# duplicate metrics should be filtered out
params$metric <- as.list(unique(unlist(params$metric)))

return(params)
}
Loading