[R] Add a compatibility layer to load Booster object from an old RDS file (#5940)
* [R] Add a compatibility layer to load a Booster from an old RDS file
* Modify QuantileHistMaker::LoadConfig() to be backward compatible with 1.1.x
* Add a prominent warning about compatibility in QuantileHistMaker::LoadConfig()
* Add a testing suite
* Discourage use of saveRDS() in the CRAN documentation
This commit is contained in:
parent
40361043ae
commit
ace7fd328b
@ -308,6 +308,20 @@ xgb.createFolds <- function(y, k = 10)
|
||||
#' @name xgboost-deprecated
|
||||
NULL
|
||||
|
||||
#' Do not use saveRDS() for long-term archival of models. Use xgb.save() instead.
|
||||
#'
|
||||
#' It is a common practice to use the built-in \code{saveRDS()} function to persist R objects to
|
||||
#' the disk. While \code{xgb.Booster} objects can be persisted with \code{saveRDS()} as well, it
|
||||
#' is not advisable to use it if the model is to be accessed in the future. If you train a model
|
||||
#' with the current version of XGBoost and persist it with \code{saveRDS()}, the model is not
|
||||
#' guaranteed to be accessible in later releases of XGBoost. To ensure that your model can be
|
||||
#' accessed in future releases of XGBoost, use \code{xgb.save()} instead. For more details and
|
||||
#' explanation, consult the page
|
||||
#' \url{https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html}.
|
||||
#'
|
||||
#' @name a-compatibility-note-for-saveRDS
|
||||
NULL
|
||||
|
||||
# Lookup table for the deprecated parameters bookkeeping
|
||||
depr_par_lut <- matrix(c(
|
||||
'print.every.n', 'print_every_n',
|
||||
|
||||
@ -6,7 +6,26 @@
|
||||
xgb.unserialize <- function(buffer) {
  # Restore an xgb.Booster.handle from a raw buffer produced by xgb.serialize().
  # Buffers produced by XGBoost <= 1.0.0 (e.g. revived from an old RDS file)
  # lack the serialisation header; for those we fall back to the plain model
  # loader so old models remain readable.
  cachelist <- list()
  handle <- .Call(XGBoosterCreate_R, cachelist)
  tryCatch(
    .Call(XGBoosterUnserializeFromBuffer_R, handle, buffer),
    error = function(e) {
      # Detect the specific "missing serialisation header" check failure
      # raised by src/learner.cc; any other error is re-thrown unchanged.
      error_msg <- conditionMessage(e)
      m <- regexec("(src[\\\\/]learner.cc:[0-9]+): Check failed: (header == serialisation_header_)",
                   error_msg, perl = TRUE)
      groups <- regmatches(error_msg, m)[[1]]
      if (length(groups) == 3) {
        warning(paste("The model had been generated by XGBoost version 1.0.0 or earlier and was ",
                      "loaded from a RDS file. We strongly ADVISE AGAINST using saveRDS() ",
                      "function, to ensure that your model can be read in current and upcoming ",
                      "XGBoost releases. Please use xgb.save() instead to preserve models for the ",
                      "long term. For more details and explanation, see ",
                      "https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html",
                      sep = ""))
        # Legacy path: interpret the buffer as a bare model (no config header).
        .Call(XGBoosterLoadModelFromRaw_R, handle, buffer)
      } else {
        stop(e)
      }
    })
  class(handle) <- "xgb.Booster.handle"
  return (handle)
}
||||
15
R-package/man/a-compatibility-note-for-saveRDS.Rd
Normal file
15
R-package/man/a-compatibility-note-for-saveRDS.Rd
Normal file
@ -0,0 +1,15 @@
|
||||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/utils.R
|
||||
\name{a-compatibility-note-for-saveRDS}
|
||||
\alias{a-compatibility-note-for-saveRDS}
|
||||
\title{Do not use saveRDS() for long-term archival of models. Use xgb.save() instead.}
|
||||
\description{
|
||||
It is a common practice to use the built-in \code{saveRDS()} function to persist R objects to
|
||||
the disk. While \code{xgb.Booster} objects can be persisted with \code{saveRDS()} as well, it
|
||||
is not advisable to use it if the model is to be accessed in the future. If you train a model
|
||||
with the current version of XGBoost and persist it with \code{saveRDS()}, the model is not
|
||||
guaranteed to be accessible in later releases of XGBoost. To ensure that your model can be
|
||||
accessed in future releases of XGBoost, use \code{xgb.save()} instead. For more details and
|
||||
explanation, consult the page
|
||||
\url{https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html}.
|
||||
}
|
||||
@ -375,7 +375,7 @@ SEXP XGBoosterSaveJsonConfig_R(SEXP handle) {
|
||||
|
||||
/* Load a Booster's internal JSON configuration string into the handle.
 * CHECK_CALL raises an R error when the C API reports failure, instead of
 * silently discarding the return code. */
SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value) {
  R_API_BEGIN();
  CHECK_CALL(XGBoosterLoadJsonConfig(R_ExternalPtrAddr(handle), CHAR(asChar(value))));
  R_API_END();
  return R_NilValue;
}
||||
@ -397,9 +397,9 @@ SEXP XGBoosterSerializeToBuffer_R(SEXP handle) {
|
||||
|
||||
/* Restore a Booster (model + internal configuration) from a raw vector.
 * The call is wrapped in CHECK_CALL so a failed unserialization propagates
 * as an R error (xgb.unserialize() relies on catching it). */
SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw) {
  R_API_BEGIN();
  CHECK_CALL(XGBoosterUnserializeFromBuffer(R_ExternalPtrAddr(handle),
                                            RAW(raw),
                                            length(raw)));
  R_API_END();
  return R_NilValue;
}
||||
|
||||
94
R-package/tests/generate_models.R
Normal file
94
R-package/tests/generate_models.R
Normal file
@ -0,0 +1,94 @@
|
||||
# Script to generate reference models. The reference models are used to test backward compatibility
# of saved model files from XGBoost version 0.90 and 1.0.x.
library(xgboost)
library(Matrix)
source('./generate_models_params.R')

# Fixed seed so that the generated reference models are reproducible across runs.
set.seed(0)
metadata <- model_generator_metadata()
# Random sparse design matrix and per-row weights shared by every model below.
X <- Matrix(data = rnorm(metadata$kRows * metadata$kCols), nrow = metadata$kRows,
            ncol = metadata$kCols, sparse = TRUE)
w <- runif(metadata$kRows)

version <- packageVersion('xgboost')
target_dir <- 'models'
||||
save_booster <- function (booster, model_name) {
  # Persist one booster under 'models/' in every supported format:
  # binary (all versions), RDS (all versions), JSON (>= 1.0.0 only).
  model_path <- function (ext) {
    file.path(target_dir, paste('xgboost-', version, '.', model_name, '.', ext, sep = ''))
  }
  xgb.save(booster, model_path('bin'))
  saveRDS(booster, model_path('rds'))
  if (version >= '1.0.0') {
    xgb.save(booster, model_path('json'))
  }
}
||||
|
||||
generate_regression_model <- function () {
  # Train and save a reference regression booster (default squared-error objective).
  print('Regression')
  label <- rnorm(metadata$kRows)

  dtrain <- xgb.DMatrix(X, label = label)
  train_params <- list(tree_method = 'hist', num_parallel_tree = metadata$kForests,
                       max_depth = metadata$kMaxDepth)
  bst <- xgb.train(train_params, dtrain, nrounds = metadata$kRounds)
  save_booster(bst, 'reg')
}
||||
|
||||
generate_logistic_model <- function () {
  # Train and save a reference binary-logistic booster with per-row weights.
  print('Binary classification with logistic loss')
  label <- sample(0:1, size = metadata$kRows, replace = TRUE)
  # Guard against a degenerate random draw missing one of the two classes.
  stopifnot(max(label) == 1, min(label) == 0)

  dtrain <- xgb.DMatrix(X, label = label, weight = w)
  train_params <- list(tree_method = 'hist', num_parallel_tree = metadata$kForests,
                       max_depth = metadata$kMaxDepth, objective = 'binary:logistic')
  bst <- xgb.train(train_params, dtrain, nrounds = metadata$kRounds)
  save_booster(bst, 'logit')
}
||||
|
||||
generate_classification_model <- function () {
  # Train and save a reference multi-class (softmax) booster.
  print('Multi-class classification')
  label <- sample(0:(metadata$kClasses - 1), size = metadata$kRows, replace = TRUE)
  # Guard against a degenerate random draw missing a class.
  stopifnot(max(label) == metadata$kClasses - 1, min(label) == 0)

  dtrain <- xgb.DMatrix(X, label = label, weight = w)
  train_params <- list(num_class = metadata$kClasses, tree_method = 'hist',
                       num_parallel_tree = metadata$kForests, max_depth = metadata$kMaxDepth,
                       objective = 'multi:softmax')
  bst <- xgb.train(train_params, dtrain, nrounds = metadata$kRounds)
  save_booster(bst, 'cls')
}
||||
|
||||
generate_ranking_model <- function () {
  # Train and save a reference learning-to-rank (rank:ndcg) booster with
  # 20 query groups of 50 rows each and per-group weights.
  print('Learning to rank')
  label <- sample(0:4, size = metadata$kRows, replace = TRUE)
  stopifnot(max(label) == 4, min(label) == 0)
  kGroups <- 20
  group_weights <- runif(kGroups)
  group_sizes <- rep(50, times = kGroups)

  dtrain <- xgb.DMatrix(X, label = label, group = group_sizes)
  # setinfo(dtrain, 'weight', group_weights) does not work in version <= 1.1.0;
  # see https://github.com/dmlc/xgboost/issues/5942.
  # So call the low-level function XGDMatrixSetInfo_R directly. Since this
  # function is not an exported symbol, use the triple-colon operator.
  .Call(xgboost:::XGDMatrixSetInfo_R, dtrain, 'weight', as.numeric(group_weights))
  train_params <- list(objective = 'rank:ndcg', num_parallel_tree = metadata$kForests,
                       tree_method = 'hist', max_depth = metadata$kMaxDepth)
  bst <- xgb.train(train_params, dtrain, nrounds = metadata$kRounds)
  save_booster(bst, 'ltr')
}
||||
|
||||
# Create the output directory and generate every reference model.
dir.create(target_dir)

invisible(generate_regression_model())
invisible(generate_logistic_model())
invisible(generate_classification_model())
invisible(generate_ranking_model())
||||
10
R-package/tests/generate_models_params.R
Normal file
10
R-package/tests/generate_models_params.R
Normal file
@ -0,0 +1,10 @@
|
||||
model_generator_metadata <- function() {
  # Shared constants for reference-model generation and the compatibility
  # tests that consume those models. Keep the two sides in sync by always
  # reading dimensions from this single source of truth.
  list(
    kRounds = 2,     # boosting rounds per model
    kRows = 1000,    # rows in the training matrix
    kCols = 4,       # feature columns
    kForests = 2,    # trees per round (num_parallel_tree)
    kMaxDepth = 2,   # maximum tree depth
    kClasses = 3     # classes for the multi-class model
  )
}
||||
@ -1,4 +1,4 @@
|
||||
library(testthat)
library(xgboost)

# ProgressReporter prints per-test progress, which keeps long-running CI
# test runs from appearing stalled.
test_check("xgboost", reporter = ProgressReporter)
||||
|
||||
@ -2,7 +2,7 @@ context("Code is of high quality and lint free")
|
||||
test_that("Code Lint", {
|
||||
skip_on_cran()
|
||||
my_linters <- list(
|
||||
absolute_paths_linter = lintr::absolute_paths_linter,
|
||||
absolute_path_linter = lintr::absolute_path_linter,
|
||||
assignment_linter = lintr::assignment_linter,
|
||||
closed_curly_linter = lintr::closed_curly_linter,
|
||||
commas_linter = lintr::commas_linter,
|
||||
|
||||
77
R-package/tests/testthat/test_model_compatibility.R
Normal file
77
R-package/tests/testthat/test_model_compatibility.R
Normal file
@ -0,0 +1,77 @@
|
||||
# Backward-compatibility tests: models saved by older XGBoost releases
# (binary, JSON, and RDS formats) must still load and predict correctly.
require(xgboost)
require(jsonlite)
source('../generate_models_params.R')

context("Models from previous versions of XGBoost can be loaded")

# Constants shared with the script that generated the reference models.
metadata <- model_generator_metadata()
|
||||
run_model_param_check <- function (config) {
  # Assertions common to every reference model: feature count and booster type.
  learner <- config$learner
  expect_equal(learner$learner_model_param$num_feature, '4')
  expect_equal(learner$learner_train_param$booster, 'gbtree')
}
||||
|
||||
get_num_tree <- function (booster) {
  # Count the trees in a booster by matching the 'booster[i]' section headers
  # in its text dump.
  dump_lines <- xgb.dump(booster)
  m <- regexec('booster\\[[0-9]+\\]', dump_lines, perl = TRUE)
  matched <- regmatches(dump_lines, m)
  # sum(lengths(...)) yields 0 for an empty dump, whereas the previous
  # Reduce('+', lapply(..., length)) returned NULL in that edge case.
  return (sum(lengths(matched)))
}
||||
|
||||
run_booster_check <- function (booster, name) {
  # Validate a loaded reference model's configuration and tree count against
  # the parameters it was generated with. 'name' selects the expected model
  # kind: 'cls', 'logit', 'ltr', or 'reg'.
  # If given a handle, we need to call xgb.Booster.complete() prior to using xgb.config().
  if (inherits(booster, "xgb.Booster") && xgboost:::is.null.handle(booster$handle)) {
    booster <- xgb.Booster.complete(booster)
  }
  config <- jsonlite::fromJSON(xgb.config(booster))
  run_model_param_check(config)
  model_param <- config$learner$learner_model_param
  train_param <- config$learner$learner_train_param
  trees_per_target <- metadata$kForests * metadata$kRounds
  if (name == 'cls') {
    # Multi-class boosters grow one forest per class per round.
    expect_equal(get_num_tree(booster), trees_per_target * metadata$kClasses)
    expect_equal(as.numeric(model_param$base_score), 0.5)
    expect_equal(train_param$objective, 'multi:softmax')
    expect_equal(as.numeric(model_param$num_class), metadata$kClasses)
  } else if (name == 'logit') {
    expect_equal(get_num_tree(booster), trees_per_target)
    expect_equal(as.numeric(model_param$num_class), 0)
    expect_equal(train_param$objective, 'binary:logistic')
  } else if (name == 'ltr') {
    expect_equal(get_num_tree(booster), trees_per_target)
    expect_equal(train_param$objective, 'rank:ndcg')
  } else {
    # Anything else must be the regression model.
    expect_equal(name, 'reg')
    expect_equal(get_num_tree(booster), trees_per_target)
    expect_equal(as.numeric(model_param$base_score), 0.5)
    expect_equal(train_param$objective, 'reg:squarederror')
  }
}
|
||||
|
||||
test_that("Models from previous versions of XGBoost can be loaded", {
  # Reference models are archived on S3 by CI; download and unpack them first.
  bucket <- 'xgboost-ci-jenkins-artifacts'
  region <- 'us-west-2'
  file_name <- 'xgboost_r_model_compatibility_test.zip'
  zipfile <- file.path(getwd(), file_name)
  model_dir <- file.path(getwd(), 'models')
  archive_url <- paste('https://', bucket, '.s3-', region, '.amazonaws.com/', file_name, sep = '')
  download.file(archive_url, destfile = zipfile, mode = 'wb')
  unzip(zipfile, overwrite = TRUE)

  pred_data <- xgb.DMatrix(matrix(c(0, 0, 0, 0), nrow = 1, ncol = 4))

  for (model_basename in list.files(model_dir)) {
    model_file <- file.path(model_dir, model_basename)
    # File names look like 'xgboost-<version>.<name>.<ext>'; extract the
    # producing version and the model kind.
    m <- regexec("xgboost-([0-9\\.]+)\\.([a-z]+)\\.[a-z]+", model_file, perl = TRUE)
    groups <- regmatches(model_file, m)[[1]]
    model_xgb_ver <- groups[2]
    name <- groups[3]

    # RDS archives go through readRDS(); binary/JSON model files through xgb.load().
    if (endsWith(model_file, '.rds')) {
      booster <- readRDS(model_file)
    } else {
      booster <- xgb.load(model_file)
    }
    predict(booster, newdata = pred_data)
    run_booster_check(booster, name)
  }
  expect_true(TRUE)
})
|
||||
@ -125,7 +125,22 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
void LoadConfig(Json const& in) override {
  // Restore updater configuration from JSON. 'train_param' exists in every
  // supported version; 'cpu_hist_train_param' was introduced after 1.1.x.
  auto const& config = get<Object const>(in);
  FromJson(config.at("train_param"), &this->param_);
  try {
    FromJson(config.at("cpu_hist_train_param"), &this->hist_maker_param_);
  } catch (std::out_of_range& e) {
    // XGBoost model is from 1.1.x, so 'cpu_hist_train_param' is missing.
    // We add this compatibility check because it's just recently that we
    // (developers) began to persuade R users away from using saveRDS() for
    // model serialization. Hopefully, one day, everyone will be using xgb.save().
    LOG(WARNING) << "Attempted to load internal configuration for a model file that was generated "
        << "by a previous version of XGBoost. A likely cause for this warning is that the model "
        << "was saved with saveRDS() in R or pickle.dump() in Python. We strongly ADVISE AGAINST "
        << "using saveRDS() or pickle.dump() so that the model remains accessible in current and "
        << "upcoming XGBoost releases. Please use xgb.save() instead to preserve models for the "
        << "long term. For more details and explanation, see "
        << "https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html";
    // Fall back to default values for the parameter group absent from the file.
    this->hist_maker_param_.UpdateAllowUnknown(Args{});
  }
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user