diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index 41475a1d494f..dcc446985d92 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -131,6 +131,25 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) { object$raw <- xgb.serialize(object$handle) } } + + attrs <- xgb.attributes(object) + if (!is.null(attrs$best_ntreelimit)) { + object$best_ntreelimit <- as.integer(attrs$best_ntreelimit) + } + if (!is.null(attrs$best_iteration)) { + ## Convert from 0 based back to 1 based. + object$best_iteration <- as.integer(attrs$best_iteration) + 1 + } + if (!is.null(attrs$best_score)) { + object$best_score <- as.numeric(attrs$best_score) + } + if (!is.null(attrs$best_msg)) { + object$best_msg <- attrs$best_msg + } + if (!is.null(attrs$niter)) { + object$niter <- as.integer(attrs$niter) + } + return(object) } diff --git a/R-package/R/xgb.load.R b/R-package/R/xgb.load.R index df6a211538c0..bda4e7e0713a 100644 --- a/R-package/R/xgb.load.R +++ b/R-package/R/xgb.load.R @@ -1,30 +1,30 @@ #' Load xgboost model from binary file -#' -#' Load xgboost model from the binary model file. -#' +#' +#' Load xgboost model from the binary model file. +#' #' @param modelfile the name of the binary input file. -#' -#' @details +#' +#' @details #' The input file is expected to contain a model saved in an xgboost-internal binary format -#' using either \code{\link{xgb.save}} or \code{\link{cb.save.model}} in R, or using some -#' appropriate methods from other xgboost interfaces. E.g., a model trained in Python and +#' using either \code{\link{xgb.save}} or \code{\link{cb.save.model}} in R, or using some +#' appropriate methods from other xgboost interfaces. E.g., a model trained in Python and #' saved from there in xgboost format, could be loaded from R. -#' +#' #' Note: a model saved as an R-object, has to be loaded using corresponding R-methods, #' not \code{xgb.load}. -#' -#' @return +#' +#' @return #' An object of \code{xgb.Booster} class. -#' -#' @seealso -#' \code{\link{xgb.save}}, \code{\link{xgb.Booster.complete}}. -#' +#' +#' @seealso +#' \code{\link{xgb.save}}, \code{\link{xgb.Booster.complete}}. +#' #' @examples #' data(agaricus.train, package='xgboost') #' data(agaricus.test, package='xgboost') #' train <- agaricus.train #' test <- agaricus.test -#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2, +#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2, #' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic") #' xgb.save(bst, 'xgb.model') #' bst <- xgb.load('xgb.model') diff --git a/R-package/tests/testthat/test_callbacks.R b/R-package/tests/testthat/test_callbacks.R index 76bcd484d5c5..83c93d0758a9 100644 --- a/R-package/tests/testthat/test_callbacks.R +++ b/R-package/tests/testthat/test_callbacks.R @@ -207,7 +207,7 @@ test_that("early stopping xgb.train works", { early_stopping_rounds = 3, maximize = FALSE) , "Stopping. Best iteration") expect_false(is.null(bst$best_iteration)) - expect_lt(bst$best_iteration, 19) + expect_lt(bst$best_iteration, 4) expect_equal(bst$best_iteration, bst$best_ntreelimit) pred <- predict(bst, dtest) @@ -222,6 +222,16 @@ test_that("early stopping xgb.train works", { early_stopping_rounds = 3, maximize = FALSE, verbose = 0) ) expect_equal(bst$evaluation_log, bst0$evaluation_log) + + xgb.save(bst, "model.bin") + loaded <- xgb.load("model.bin") + + print('Start expecting something.') + expect_false(is.null(loaded$best_iteration)) + expect_equal(loaded$best_iteration, bst$best_ntreelimit) + expect_equal(loaded$best_ntreelimit, bst$best_ntreelimit) + + file.remove("model.bin") }) test_that("early stopping using a specific metric works", {