Restore attributes in complete. (#5573)

This commit is contained in:
Jiaming Yuan 2020-04-22 02:06:55 +08:00 committed by GitHub
parent a734f52807
commit 564b22cee5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 15 deletions

View File

@ -131,6 +131,25 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
object$raw <- xgb.serialize(object$handle) object$raw <- xgb.serialize(object$handle)
} }
} }
attrs <- xgb.attributes(object)
if (!is.null(attrs$best_ntreelimit)) {
object$best_ntreelimit <- as.integer(attrs$best_ntreelimit)
}
if (!is.null(attrs$best_iteration)) {
## Convert from 0 based back to 1 based.
object$best_iteration <- as.integer(attrs$best_iteration) + 1
}
if (!is.null(attrs$best_score)) {
object$best_score <- as.numeric(attrs$best_score)
}
if (!is.null(attrs$best_msg)) {
object$best_msg <- attrs$best_msg
}
if (!is.null(attrs$niter)) {
object$niter <- as.integer(attrs$niter)
}
return(object) return(object)
} }

View File

@ -1,30 +1,30 @@
#' Load xgboost model from binary file #' Load xgboost model from binary file
#' #'
#' Load xgboost model from the binary model file. #' Load xgboost model from the binary model file.
#' #'
#' @param modelfile the name of the binary input file. #' @param modelfile the name of the binary input file.
#' #'
#' @details #' @details
#' The input file is expected to contain a model saved in an xgboost-internal binary format #' The input file is expected to contain a model saved in an xgboost-internal binary format
#' using either \code{\link{xgb.save}} or \code{\link{cb.save.model}} in R, or using some #' using either \code{\link{xgb.save}} or \code{\link{cb.save.model}} in R, or using some
#' appropriate methods from other xgboost interfaces. E.g., a model trained in Python and #' appropriate methods from other xgboost interfaces. E.g., a model trained in Python and
#' saved from there in xgboost format, could be loaded from R. #' saved from there in xgboost format, could be loaded from R.
#' #'
#' Note: a model saved as an R-object, has to be loaded using corresponding R-methods, #' Note: a model saved as an R-object, has to be loaded using corresponding R-methods,
#' not \code{xgb.load}. #' not \code{xgb.load}.
#' #'
#' @return #' @return
#' An object of \code{xgb.Booster} class. #' An object of \code{xgb.Booster} class.
#' #'
#' @seealso #' @seealso
#' \code{\link{xgb.save}}, \code{\link{xgb.Booster.complete}}. #' \code{\link{xgb.save}}, \code{\link{xgb.Booster.complete}}.
#' #'
#' @examples #' @examples
#' data(agaricus.train, package='xgboost') #' data(agaricus.train, package='xgboost')
#' data(agaricus.test, package='xgboost') #' data(agaricus.test, package='xgboost')
#' train <- agaricus.train #' train <- agaricus.train
#' test <- agaricus.test #' test <- agaricus.test
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2, #' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
#' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic") #' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
#' xgb.save(bst, 'xgb.model') #' xgb.save(bst, 'xgb.model')
#' bst <- xgb.load('xgb.model') #' bst <- xgb.load('xgb.model')

View File

@ -222,6 +222,15 @@ test_that("early stopping xgb.train works", {
early_stopping_rounds = 3, maximize = FALSE, verbose = 0) early_stopping_rounds = 3, maximize = FALSE, verbose = 0)
) )
expect_equal(bst$evaluation_log, bst0$evaluation_log) expect_equal(bst$evaluation_log, bst0$evaluation_log)
xgb.save(bst, "model.bin")
loaded <- xgb.load("model.bin")
expect_false(is.null(loaded$best_iteration))
expect_equal(loaded$best_iteration, bst$best_ntreelimit)
expect_equal(loaded$best_ntreelimit, bst$best_ntreelimit)
file.remove("model.bin")
}) })
test_that("early stopping using a specific metric works", { test_that("early stopping using a specific metric works", {