Allow loading old models from RDS (#7864)

This commit is contained in:
Philip Hyunsu Cho 2022-05-06 22:49:38 -07:00 committed by GitHub
parent 1823db53f2
commit d2bc0f0f08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 27 deletions

View File

@ -77,6 +77,7 @@ test_that("Models from previous versions of XGBoost can be loaded", {
model_xgb_ver <- m[2] model_xgb_ver <- m[2]
name <- m[3] name <- m[3]
is_rds <- endsWith(model_file, '.rds') is_rds <- endsWith(model_file, '.rds')
is_json <- endsWith(model_file, '.json')
cpp_warning <- capture.output({ cpp_warning <- capture.output({
# Expect an R warning when a model is loaded from RDS and it was generated by version < 1.1.x # Expect an R warning when a model is loaded from RDS and it was generated by version < 1.1.x
@ -95,15 +96,13 @@ test_that("Models from previous versions of XGBoost can be loaded", {
run_booster_check(booster, name) run_booster_check(booster, name)
} }
}) })
if (compareVersion(model_xgb_ver, '1.0.0.0') < 0) { cpp_warning <- paste0(cpp_warning, collapse = ' ')
# Expect a C++ warning when a model was generated in version < 1.0.x if (is_rds && compareVersion(model_xgb_ver, '1.1.1.1') >= 0) {
m <- grepl(paste0('.*Loading model from XGBoost < 1\\.0\\.0, consider saving it again for ', # Expect a C++ warning when a model is loaded from RDS and it was generated by old XGBoost
'improved compatibility.*'), cpp_warning, perl = TRUE) m <- grepl(paste0('.*If you are loading a serialized model ',
expect_true(length(m) > 0 && all(m)) '\\(like pickle in Python, RDS in R\\).*',
} else if (is_rds && model_xgb_ver == '1.1.1.1') { 'for more details about differences between ',
# Expect a C++ warning when a model is loaded from RDS and it was generated by version 1.1.x 'saving model and serializing.*'), cpp_warning, perl = TRUE)
m <- grepl(paste0('.*Attempted to load internal configuration for a model file that was ',
'generated by a previous version of XGBoost.*'), cpp_warning, perl = TRUE)
expect_true(length(m) > 0 && all(m)) expect_true(length(m) > 0 && all(m))
} }
}) })

View File

@ -406,8 +406,14 @@ class LearnerConfiguration : public Learner {
} }
void LoadConfig(Json const& in) override { void LoadConfig(Json const& in) override {
// If configuration is loaded, ensure that the model came from the same version
CHECK(IsA<Object>(in)); CHECK(IsA<Object>(in));
Version::Load(in); auto origin_version = Version::Load(in);
if (!Version::Same(origin_version)) {
LOG(WARNING) << ModelMsg();
return; // skip configuration if version is not matched
}
auto const& learner_parameters = get<Object>(in["learner"]); auto const& learner_parameters = get<Object>(in["learner"]);
FromJson(learner_parameters.at("learner_train_param"), &tparam_); FromJson(learner_parameters.at("learner_train_param"), &tparam_);

View File

@ -249,23 +249,7 @@ class QuantileHistMaker: public TreeUpdater {
void LoadConfig(Json const& in) override { void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in); auto const& config = get<Object const>(in);
FromJson(config.at("train_param"), &this->param_); FromJson(config.at("train_param"), &this->param_);
try { FromJson(config.at("cpu_hist_train_param"), &this->hist_maker_param_);
FromJson(config.at("cpu_hist_train_param"), &this->hist_maker_param_);
} catch (std::out_of_range&) {
// XGBoost model is from 1.1.x, so 'cpu_hist_train_param' is missing.
// We add this compatibility check because it's just recently that we (developers) began
// to persuade R users away from using saveRDS() for model serialization. Hopefully, one day,
// everyone will be using xgb.save().
LOG(WARNING)
<< "Attempted to load internal configuration for a model file that was generated "
<< "by a previous version of XGBoost. A likely cause for this warning is that the model "
<< "was saved with saveRDS() in R or pickle.dump() in Python. We strongly ADVISE AGAINST "
<< "using saveRDS() or pickle.dump() so that the model remains accessible in current and "
<< "upcoming XGBoost releases. Please use xgb.save() instead to preserve models for the "
<< "long term. For more details and explanation, see "
<< "https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html";
this->hist_maker_param_.UpdateAllowUnknown(Args{});
}
} }
void SaveConfig(Json* p_out) const override { void SaveConfig(Json* p_out) const override {
auto& out = *p_out; auto& out = *p_out;