diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index 311d3f2b910c..f4b729408801 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -26,108 +26,90 @@ Booster <- R6::R6Class( modelfile = NULL, model_str = NULL) { - # Create parameters and handle handle <- NULL - # Attempts to create a handle for the dataset - try({ - - # Check if training dataset is not null - if (!is.null(train_set)) { - # Check if training dataset is lgb.Dataset or not - if (!lgb.is.Dataset(train_set)) { - stop("lgb.Booster: Can only use lgb.Dataset as training data") - } - train_set_handle <- train_set$.__enclos_env__$private$get_handle() - params <- utils::modifyList(params, train_set$get_params()) - params_str <- lgb.params2str(params = params) - # Store booster handle - handle <- .Call( - LGBM_BoosterCreate_R - , train_set_handle - , params_str - ) - - # Create private booster information - private$train_set <- train_set - private$train_set_version <- train_set$.__enclos_env__$private$version - private$num_dataset <- 1L - private$init_predictor <- train_set$.__enclos_env__$private$predictor - - # Check if predictor is existing - if (!is.null(private$init_predictor)) { - - # Merge booster - .Call( - LGBM_BoosterMerge_R - , handle - , private$init_predictor$.__enclos_env__$private$handle - ) - - } - - # Check current iteration - private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE) + if (!is.null(train_set)) { - } else if (!is.null(modelfile)) { + if (!lgb.is.Dataset(train_set)) { + stop("lgb.Booster: Can only use lgb.Dataset as training data") + } + train_set_handle <- train_set$.__enclos_env__$private$get_handle() + params <- utils::modifyList(params, train_set$get_params()) + params_str <- lgb.params2str(params = params) + # Store booster handle + handle <- .Call( + LGBM_BoosterCreate_R + , train_set_handle + , params_str + ) - # Do we have a model file as character? - if (!is.character(modelfile)) { - stop("lgb.Booster: Can only use a string as model file path") - } + # Create private booster information + private$train_set <- train_set + private$train_set_version <- train_set$.__enclos_env__$private$version + private$num_dataset <- 1L + private$init_predictor <- train_set$.__enclos_env__$private$predictor - modelfile <- path.expand(modelfile) + if (!is.null(private$init_predictor)) { - # Create booster from model - handle <- .Call( - LGBM_BoosterCreateFromModelfile_R - , modelfile + # Merge booster + .Call( + LGBM_BoosterMerge_R + , handle + , private$init_predictor$.__enclos_env__$private$handle ) - } else if (!is.null(model_str)) { + } - # Do we have a model_str as character/raw? - if (!is.raw(model_str) && !is.character(model_str)) { - stop("lgb.Booster: Can only use a character/raw vector as model_str") - } + # Check current iteration + private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE) - # Create booster from model - handle <- .Call( - LGBM_BoosterLoadModelFromString_R - , model_str - ) + } else if (!is.null(modelfile)) { - } else { + # Do we have a model file as character? + if (!is.character(modelfile)) { + stop("lgb.Booster: Can only use a string as model file path") + } - # Booster non existent - stop( - "lgb.Booster: Need at least either training dataset, " - , "model file, or model_str to create booster instance" - ) + modelfile <- path.expand(modelfile) - } + # Create booster from model + handle <- .Call( + LGBM_BoosterCreateFromModelfile_R + , modelfile + ) - }) + } else if (!is.null(model_str)) { - # Check whether the handle was created properly if it was not stopped earlier by a stop call - if (isTRUE(lgb.is.null.handle(x = handle))) { + # Do we have a model_str as character/raw? + if (!is.raw(model_str) && !is.character(model_str)) { + stop("lgb.Booster: Can only use a character/raw vector as model_str") + } - stop("lgb.Booster: cannot create Booster handle") + # Create booster from model + handle <- .Call( + LGBM_BoosterLoadModelFromString_R + , model_str + ) } else { - # Create class - class(handle) <- "lgb.Booster.handle" - private$handle <- handle - private$num_class <- 1L - .Call( - LGBM_BoosterGetNumClasses_R - , private$handle - , private$num_class + # Booster non existent + stop( + "lgb.Booster: Need at least either training dataset, " + , "model file, or model_str to create booster instance" ) } + class(handle) <- "lgb.Booster.handle" + private$handle <- handle + private$num_class <- 1L + .Call( + LGBM_BoosterGetNumClasses_R + , private$handle + , private$num_class + ) + self$params <- params return(invisible(NULL)) diff --git a/R-package/tests/testthat/test_lgb.Booster.R b/R-package/tests/testthat/test_lgb.Booster.R index 6984096064c2..9929c0ee6abc 100644 --- a/R-package/tests/testthat/test_lgb.Booster.R +++ b/R-package/tests/testthat/test_lgb.Booster.R @@ -947,7 +947,76 @@ test_that("Booster$new() using a Dataset with a null handle should raise an info verbose = VERBOSITY ) ) - }, regexp = "lgb.Booster: cannot create Booster handle") + }, regexp = "Attempting to create a Dataset without any raw data") +}) + +test_that("Booster$new() raises informative errors for malformed inputs", { + data(agaricus.train, package = "lightgbm") + train <- agaricus.train + dtrain <- lgb.Dataset(train$data, label = train$label) + + # no inputs + expect_error({ + Booster$new() + }, regexp = "lgb.Booster: Need at least either training dataset, model file, or model_str") + + # unrecognized objective + expect_error({ + Booster$new( + params = list(objective = "not_a_real_objective") + , train_set = dtrain + ) + }, regexp = "Unknown objective type name: not_a_real_objective") + + # train_set is not a Dataset + expect_error({ + Booster$new( + train_set = data.table::data.table(rnorm(1L:10L)) + ) + }, regexp = "lgb.Booster: Can only use lgb.Dataset as training data") + + # model file isn't a string + expect_error({ + Booster$new( + modelfile = list() + ) + }, regexp = "lgb.Booster: Can only use a string as model file path") + + # model file doesn't exist + expect_error({ + Booster$new( + params = list() + , modelfile = "file-that-does-not-exist.model" + ) + }, regexp = "Could not open file-that-does-not-exist.model") + + # model file doesn't contain a valid LightGBM model + model_file <- tempfile(fileext = ".model") + writeLines( + text = c("make", "good", "predictions") + , con = model_file + ) + expect_error({ + Booster$new( + params = list() + , modelfile = model_file + ) + }, regexp = "Unknown model format or submodel type in model file") + + # malformed model string + expect_error({ + Booster$new( + params = list() + , model_str = "a\nb\n" + ) + }, regexp = "Model file doesn't specify the number of classes") + + # model string isn't character or raw + expect_error({ + Booster$new( + model_str = numeric() + ) + }, regexp = "lgb.Booster: Can only use a character/raw vector as model_str") }) # this is almost identical to the test above it, but for lgb.cv(). A lot of code