Fix model parameter recovery (#4738)
This commit is contained in:
parent
851b5b3808
commit
3e2c472944
@ -287,9 +287,13 @@ class LearnerImpl : public Learner {
|
|||||||
metrics_.emplace_back(Metric::Create(name, &generic_param_));
|
metrics_.emplace_back(Metric::Create(name, &generic_param_));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg_["num_class"] = common::ToString(mparam_.num_class);
|
cfg_["num_class"] = common::ToString(mparam_.num_class);
|
||||||
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
||||||
|
|
||||||
|
auto n = tparam_.__DICT__();
|
||||||
|
cfg_.insert(n.cbegin(), n.cend());
|
||||||
|
|
||||||
gbm_->Configure({cfg_.cbegin(), cfg_.cend()});
|
gbm_->Configure({cfg_.cbegin(), cfg_.cend()});
|
||||||
obj_->Configure({cfg_.begin(), cfg_.end()});
|
obj_->Configure({cfg_.begin(), cfg_.end()});
|
||||||
|
|
||||||
|
|||||||
@ -111,6 +111,45 @@ TEST(Learner, Configuration) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Learner, ObjectiveParameter) {
|
||||||
|
using Arg = std::pair<std::string, std::string>;
|
||||||
|
size_t constexpr kRows = 10;
|
||||||
|
auto pp_dmat = CreateDMatrix(kRows, 10, 0);
|
||||||
|
auto p_dmat = *pp_dmat;
|
||||||
|
|
||||||
|
std::vector<bst_float> labels(kRows);
|
||||||
|
for (size_t i = 0; i < labels.size(); ++i) {
|
||||||
|
labels[i] = i;
|
||||||
|
}
|
||||||
|
p_dmat->Info().labels_.HostVector() = labels;
|
||||||
|
std::vector<std::shared_ptr<DMatrix>> mat {p_dmat};
|
||||||
|
|
||||||
|
std::unique_ptr<Learner> learner {Learner::Create(mat)};
|
||||||
|
learner->SetParams({Arg{"tree_method", "auto"},
|
||||||
|
Arg{"objective", "multi:softprob"},
|
||||||
|
Arg{"num_class", "10"}});
|
||||||
|
learner->UpdateOneIter(0, p_dmat.get());
|
||||||
|
auto attr_names = learner->GetConfigurationArguments();
|
||||||
|
ASSERT_EQ(attr_names.at("objective"), "multi:softprob");
|
||||||
|
|
||||||
|
dmlc::TemporaryDirectory tempdir;
|
||||||
|
const std::string fname = tempdir.path + "/model_para.bst";
|
||||||
|
|
||||||
|
{
|
||||||
|
// Create a scope to close the stream before next read.
|
||||||
|
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
|
||||||
|
learner->Save(fo.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r"));
|
||||||
|
std::unique_ptr<Learner> learner1 {Learner::Create(mat)};
|
||||||
|
learner1->Load(fi.get());
|
||||||
|
auto attr_names1 = learner1->GetConfigurationArguments();
|
||||||
|
ASSERT_EQ(attr_names1.at("objective"), "multi:softprob");
|
||||||
|
|
||||||
|
delete pp_dmat;
|
||||||
|
}
|
||||||
|
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
|
|
||||||
TEST(Learner, IO) {
|
TEST(Learner, IO) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user