Fix loading old model. (#5724) (#5737)

* Add test.
This commit is contained in:
Jiaming Yuan 2020-06-01 04:32:24 +08:00 committed by GitHub
parent e74560c86a
commit 8467880aeb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 7 deletions

View File

@ -689,15 +689,23 @@ class LearnerIO : public LearnerConfiguration {
warn_old_model = false; warn_old_model = false;
} }
if (mparam_.major_version >= 1) { if (mparam_.major_version < 1) {
learner_model_param_ = LearnerModelParam(mparam_,
obj_->ProbToMargin(mparam_.base_score));
} else {
// Before 1.0.0, base_score is saved as a transformed value, and there's no version // Before 1.0.0, base_score is saved as a transformed value, and there's no version
// attribute in the saved model. // attribute (saved a 0) in the saved model.
learner_model_param_ = LearnerModelParam(mparam_, mparam_.base_score); std::string multi{"multi:"};
if (!std::equal(multi.cbegin(), multi.cend(), tparam_.objective.cbegin())) {
HostDeviceVector<float> t;
t.HostVector().resize(1);
t.HostVector().at(0) = mparam_.base_score;
this->obj_->PredTransform(&t);
auto base_score = t.HostVector().at(0);
mparam_.base_score = base_score;
}
warn_old_model = true; warn_old_model = true;
} }
learner_model_param_ =
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score));
if (attributes_.find("objective") != attributes_.cend()) { if (attributes_.find("objective") != attributes_.cend()) {
auto obj_str = attributes_.at("objective"); auto obj_str = attributes_.at("objective");
auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()}); auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()});

View File

@ -1,3 +1,4 @@
[pytest] [pytest]
markers = markers =
mgpu: Mark a test that requires multiple GPUs to run. mgpu: Mark a test that requires multiple GPUs to run.
ci: Mark a test that runs only on CI.

View File

@ -4,6 +4,7 @@ import generate_models as gm
import json import json
import zipfile import zipfile
import pytest import pytest
import copy
def run_model_param_check(config): def run_model_param_check(config):
@ -124,6 +125,9 @@ def test_model_compatibility():
if name.startswith('xgboost-'): if name.startswith('xgboost-'):
booster = xgboost.Booster(model_file=path) booster = xgboost.Booster(model_file=path)
run_booster_check(booster, name) run_booster_check(booster, name)
# Do full serialization.
booster = copy.copy(booster)
run_booster_check(booster, name)
elif name.startswith('xgboost_scikit'): elif name.startswith('xgboost_scikit'):
run_scikit_model_check(name, path) run_scikit_model_check(name, path)
else: else: