Skip to content

Commit

Permalink
Restore attributes in complete.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Apr 21, 2020
1 parent 9c1103e commit fa9bc45
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 16 deletions.
19 changes: 19 additions & 0 deletions R-package/R/xgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,25 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
object$raw <- xgb.serialize(object$handle)
}
}

attrs <- xgb.attributes(object)
if (!is.null(attrs$best_ntreelimit)) {
object$best_ntreelimit <- as.integer(attrs$best_ntreelimit)
}
if (!is.null(attrs$best_iteration)) {
## Convert from 0 based back to 1 based.
object$best_iteration <- as.integer(attrs$best_iteration) + 1
}
if (!is.null(attrs$best_score)) {
object$best_score <- as.numeric(attrs$best_score)
}
if (!is.null(attrs$best_msg)) {
object$best_msg <- attrs$best_msg
}
if (!is.null(attrs$niter)) {
object$niter <- as.integer(attrs$niter)
}

return(object)
}

Expand Down
30 changes: 15 additions & 15 deletions R-package/R/xgb.load.R
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
#' Load xgboost model from binary file
#'
#' Load xgboost model from the binary model file.
#'
#'
#' Load xgboost model from the binary model file.
#'
#' @param modelfile the name of the binary input file.
#'
#' @details
#'
#' @details
#' The input file is expected to contain a model saved in an xgboost-internal binary format
#' using either \code{\link{xgb.save}} or \code{\link{cb.save.model}} in R, or using some
#' appropriate methods from other xgboost interfaces. E.g., a model trained in Python and
#' using either \code{\link{xgb.save}} or \code{\link{cb.save.model}} in R, or using some
#' appropriate methods from other xgboost interfaces. E.g., a model trained in Python and
#' saved from there in xgboost format, could be loaded from R.
#'
#'
#' Note: a model saved as an R-object, has to be loaded using corresponding R-methods,
#' not \code{xgb.load}.
#'
#' @return
#'
#' @return
#' An object of \code{xgb.Booster} class.
#'
#' @seealso
#' \code{\link{xgb.save}}, \code{\link{xgb.Booster.complete}}.
#'
#'
#' @seealso
#' \code{\link{xgb.save}}, \code{\link{xgb.Booster.complete}}.
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#' data(agaricus.test, package='xgboost')
#' train <- agaricus.train
#' test <- agaricus.test
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
#' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
#' xgb.save(bst, 'xgb.model')
#' bst <- xgb.load('xgb.model')
Expand Down
12 changes: 11 additions & 1 deletion R-package/tests/testthat/test_callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ test_that("early stopping xgb.train works", {
early_stopping_rounds = 3, maximize = FALSE)
, "Stopping. Best iteration")
expect_false(is.null(bst$best_iteration))
expect_lt(bst$best_iteration, 19)
expect_lt(bst$best_iteration, 4)
expect_equal(bst$best_iteration, bst$best_ntreelimit)

pred <- predict(bst, dtest)
Expand All @@ -222,6 +222,16 @@ test_that("early stopping xgb.train works", {
early_stopping_rounds = 3, maximize = FALSE, verbose = 0)
)
expect_equal(bst$evaluation_log, bst0$evaluation_log)

xgb.save(bst, "model.bin")
loaded <- xgb.load("model.bin")

print('Start expecting something.')
expect_false(is.null(loaded$best_iteration))
expect_equal(loaded$best_iteration, bst$best_ntreelimit)
expect_equal(loaded$best_ntreelimit, bst$best_ntreelimit)

file.remove("model.bin")
})

test_that("early stopping using a specific metric works", {
Expand Down

0 comments on commit fa9bc45

Please sign in to comment.