Allow loading old models from RDS (#7864)
This commit is contained in:
parent
1823db53f2
commit
d2bc0f0f08
@ -77,6 +77,7 @@ test_that("Models from previous versions of XGBoost can be loaded", {
|
|||||||
model_xgb_ver <- m[2]
|
model_xgb_ver <- m[2]
|
||||||
name <- m[3]
|
name <- m[3]
|
||||||
is_rds <- endsWith(model_file, '.rds')
|
is_rds <- endsWith(model_file, '.rds')
|
||||||
|
is_json <- endsWith(model_file, '.json')
|
||||||
|
|
||||||
cpp_warning <- capture.output({
|
cpp_warning <- capture.output({
|
||||||
# Expect an R warning when a model is loaded from RDS and it was generated by version < 1.1.x
|
# Expect an R warning when a model is loaded from RDS and it was generated by version < 1.1.x
|
||||||
@ -95,15 +96,13 @@ test_that("Models from previous versions of XGBoost can be loaded", {
|
|||||||
run_booster_check(booster, name)
|
run_booster_check(booster, name)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if (compareVersion(model_xgb_ver, '1.0.0.0') < 0) {
|
cpp_warning <- paste0(cpp_warning, collapse = ' ')
|
||||||
# Expect a C++ warning when a model was generated in version < 1.0.x
|
if (is_rds && compareVersion(model_xgb_ver, '1.1.1.1') >= 0) {
|
||||||
m <- grepl(paste0('.*Loading model from XGBoost < 1\\.0\\.0, consider saving it again for ',
|
# Expect a C++ warning when a model is loaded from RDS and it was generated by old XGBoost`
|
||||||
'improved compatibility.*'), cpp_warning, perl = TRUE)
|
m <- grepl(paste0('.*If you are loading a serialized model ',
|
||||||
expect_true(length(m) > 0 && all(m))
|
'\\(like pickle in Python, RDS in R\\).*',
|
||||||
} else if (is_rds && model_xgb_ver == '1.1.1.1') {
|
'for more details about differences between ',
|
||||||
# Expect a C++ warning when a model is loaded from RDS and it was generated by version 1.1.x
|
'saving model and serializing.*'), cpp_warning, perl = TRUE)
|
||||||
m <- grepl(paste0('.*Attempted to load internal configuration for a model file that was ',
|
|
||||||
'generated by a previous version of XGBoost.*'), cpp_warning, perl = TRUE)
|
|
||||||
expect_true(length(m) > 0 && all(m))
|
expect_true(length(m) > 0 && all(m))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@ -406,8 +406,14 @@ class LearnerConfiguration : public Learner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void LoadConfig(Json const& in) override {
|
void LoadConfig(Json const& in) override {
|
||||||
|
// If configuration is loaded, ensure that the model came from the same version
|
||||||
CHECK(IsA<Object>(in));
|
CHECK(IsA<Object>(in));
|
||||||
Version::Load(in);
|
auto origin_version = Version::Load(in);
|
||||||
|
|
||||||
|
if (!Version::Same(origin_version)) {
|
||||||
|
LOG(WARNING) << ModelMsg();
|
||||||
|
return; // skip configuration if version is not matched
|
||||||
|
}
|
||||||
|
|
||||||
auto const& learner_parameters = get<Object>(in["learner"]);
|
auto const& learner_parameters = get<Object>(in["learner"]);
|
||||||
FromJson(learner_parameters.at("learner_train_param"), &tparam_);
|
FromJson(learner_parameters.at("learner_train_param"), &tparam_);
|
||||||
|
|||||||
@ -249,23 +249,7 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
void LoadConfig(Json const& in) override {
|
void LoadConfig(Json const& in) override {
|
||||||
auto const& config = get<Object const>(in);
|
auto const& config = get<Object const>(in);
|
||||||
FromJson(config.at("train_param"), &this->param_);
|
FromJson(config.at("train_param"), &this->param_);
|
||||||
try {
|
FromJson(config.at("cpu_hist_train_param"), &this->hist_maker_param_);
|
||||||
FromJson(config.at("cpu_hist_train_param"), &this->hist_maker_param_);
|
|
||||||
} catch (std::out_of_range&) {
|
|
||||||
// XGBoost model is from 1.1.x, so 'cpu_hist_train_param' is missing.
|
|
||||||
// We add this compatibility check because it's just recently that we (developers) began
|
|
||||||
// persuade R users away from using saveRDS() for model serialization. Hopefully, one day,
|
|
||||||
// everyone will be using xgb.save().
|
|
||||||
LOG(WARNING)
|
|
||||||
<< "Attempted to load internal configuration for a model file that was generated "
|
|
||||||
<< "by a previous version of XGBoost. A likely cause for this warning is that the model "
|
|
||||||
<< "was saved with saveRDS() in R or pickle.dump() in Python. We strongly ADVISE AGAINST "
|
|
||||||
<< "using saveRDS() or pickle.dump() so that the model remains accessible in current and "
|
|
||||||
<< "upcoming XGBoost releases. Please use xgb.save() instead to preserve models for the "
|
|
||||||
<< "long term. For more details and explanation, see "
|
|
||||||
<< "https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html";
|
|
||||||
this->hist_maker_param_.UpdateAllowUnknown(Args{});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
void SaveConfig(Json* p_out) const override {
|
void SaveConfig(Json* p_out) const override {
|
||||||
auto& out = *p_out;
|
auto& out = *p_out;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user