@@ -689,15 +689,23 @@ class LearnerIO : public LearnerConfiguration {
|
||||
warn_old_model = false;
|
||||
}
|
||||
|
||||
if (mparam_.major_version >= 1) {
|
||||
learner_model_param_ = LearnerModelParam(mparam_,
|
||||
obj_->ProbToMargin(mparam_.base_score));
|
||||
} else {
|
||||
if (mparam_.major_version < 1) {
|
||||
// Before 1.0.0, base_score is saved as a transformed value, and there's no version
|
||||
// attribute in the saved model.
|
||||
learner_model_param_ = LearnerModelParam(mparam_, mparam_.base_score);
|
||||
// attribute (saved a 0) in the saved model.
|
||||
std::string multi{"multi:"};
|
||||
if (!std::equal(multi.cbegin(), multi.cend(), tparam_.objective.cbegin())) {
|
||||
HostDeviceVector<float> t;
|
||||
t.HostVector().resize(1);
|
||||
t.HostVector().at(0) = mparam_.base_score;
|
||||
this->obj_->PredTransform(&t);
|
||||
auto base_score = t.HostVector().at(0);
|
||||
mparam_.base_score = base_score;
|
||||
}
|
||||
warn_old_model = true;
|
||||
}
|
||||
|
||||
learner_model_param_ =
|
||||
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score));
|
||||
if (attributes_.find("objective") != attributes_.cend()) {
|
||||
auto obj_str = attributes_.at("objective");
|
||||
auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()});
|
||||
|
||||
Reference in New Issue
Block a user