parent
e74560c86a
commit
8467880aeb
@ -689,15 +689,23 @@ class LearnerIO : public LearnerConfiguration {
|
|||||||
warn_old_model = false;
|
warn_old_model = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mparam_.major_version >= 1) {
|
if (mparam_.major_version < 1) {
|
||||||
learner_model_param_ = LearnerModelParam(mparam_,
|
|
||||||
obj_->ProbToMargin(mparam_.base_score));
|
|
||||||
} else {
|
|
||||||
// Before 1.0.0, base_score is saved as a transformed value, and there's no version
|
// Before 1.0.0, base_score is saved as a transformed value, and there's no version
|
||||||
// attribute in the saved model.
|
// attribute (saved a 0) in the saved model.
|
||||||
learner_model_param_ = LearnerModelParam(mparam_, mparam_.base_score);
|
std::string multi{"multi:"};
|
||||||
|
if (!std::equal(multi.cbegin(), multi.cend(), tparam_.objective.cbegin())) {
|
||||||
|
HostDeviceVector<float> t;
|
||||||
|
t.HostVector().resize(1);
|
||||||
|
t.HostVector().at(0) = mparam_.base_score;
|
||||||
|
this->obj_->PredTransform(&t);
|
||||||
|
auto base_score = t.HostVector().at(0);
|
||||||
|
mparam_.base_score = base_score;
|
||||||
|
}
|
||||||
warn_old_model = true;
|
warn_old_model = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
learner_model_param_ =
|
||||||
|
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score));
|
||||||
if (attributes_.find("objective") != attributes_.cend()) {
|
if (attributes_.find("objective") != attributes_.cend()) {
|
||||||
auto obj_str = attributes_.at("objective");
|
auto obj_str = attributes_.at("objective");
|
||||||
auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()});
|
auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()});
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
[pytest]
|
[pytest]
|
||||||
markers =
|
markers =
|
||||||
mgpu: Mark a test that requires multiple GPUs to run.
|
mgpu: Mark a test that requires multiple GPUs to run.
|
||||||
|
ci: Mark a test that runs only on CI.
|
||||||
@ -4,6 +4,7 @@ import generate_models as gm
|
|||||||
import json
|
import json
|
||||||
import zipfile
|
import zipfile
|
||||||
import pytest
|
import pytest
|
||||||
|
import copy
|
||||||
|
|
||||||
|
|
||||||
def run_model_param_check(config):
|
def run_model_param_check(config):
|
||||||
@ -124,6 +125,9 @@ def test_model_compatibility():
|
|||||||
if name.startswith('xgboost-'):
|
if name.startswith('xgboost-'):
|
||||||
booster = xgboost.Booster(model_file=path)
|
booster = xgboost.Booster(model_file=path)
|
||||||
run_booster_check(booster, name)
|
run_booster_check(booster, name)
|
||||||
|
# Do full serialization.
|
||||||
|
booster = copy.copy(booster)
|
||||||
|
run_booster_check(booster, name)
|
||||||
elif name.startswith('xgboost_scikit'):
|
elif name.startswith('xgboost_scikit'):
|
||||||
run_scikit_model_check(name, path)
|
run_scikit_model_check(name, path)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user