Fix metric serialization. (#9405)
This commit is contained in:
parent
dbd5309b55
commit
0de7c47495
@ -583,8 +583,9 @@ class LearnerConfiguration : public Learner {
|
|||||||
auto& objective_fn = learner_parameters["objective"];
|
auto& objective_fn = learner_parameters["objective"];
|
||||||
obj_->SaveConfig(&objective_fn);
|
obj_->SaveConfig(&objective_fn);
|
||||||
|
|
||||||
std::vector<Json> metrics(metrics_.size(), Json{Object{}});
|
std::vector<Json> metrics(metrics_.size());
|
||||||
for (size_t i = 0; i < metrics_.size(); ++i) {
|
for (size_t i = 0; i < metrics_.size(); ++i) {
|
||||||
|
metrics[i] = Object{};
|
||||||
metrics_[i]->SaveConfig(&metrics[i]);
|
metrics_[i]->SaveConfig(&metrics[i]);
|
||||||
}
|
}
|
||||||
learner_parameters["metrics"] = Array(std::move(metrics));
|
learner_parameters["metrics"] = Array(std::move(metrics));
|
||||||
@ -807,14 +808,13 @@ class LearnerConfiguration : public Learner {
|
|||||||
|
|
||||||
void ConfigureMetrics(Args const& args) {
|
void ConfigureMetrics(Args const& args) {
|
||||||
for (auto const& name : metric_names_) {
|
for (auto const& name : metric_names_) {
|
||||||
auto DupCheck = [&name](std::unique_ptr<Metric> const& m) {
|
auto DupCheck = [&name](std::unique_ptr<Metric> const& m) { return m->Name() != name; };
|
||||||
return m->Name() != name;
|
|
||||||
};
|
|
||||||
if (std::all_of(metrics_.begin(), metrics_.end(), DupCheck)) {
|
if (std::all_of(metrics_.begin(), metrics_.end(), DupCheck)) {
|
||||||
metrics_.emplace_back(std::unique_ptr<Metric>(Metric::Create(name, &ctx_)));
|
metrics_.emplace_back(std::unique_ptr<Metric>(Metric::Create(name, &ctx_)));
|
||||||
mparam_.contain_eval_metrics = 1;
|
mparam_.contain_eval_metrics = 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& p_metric : metrics_) {
|
for (auto& p_metric : metrics_) {
|
||||||
p_metric->Configure(args);
|
p_metric->Configure(args);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -215,6 +215,34 @@ TEST(Learner, JsonModelIO) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Learner, ConfigIO) {
|
||||||
|
bst_row_t n_samples = 128;
|
||||||
|
bst_feature_t n_features = 12;
|
||||||
|
std::shared_ptr<DMatrix> p_fmat{
|
||||||
|
RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true, false, 2)};
|
||||||
|
|
||||||
|
auto serialised_model_tmp = std::string{};
|
||||||
|
std::string eval_res_0;
|
||||||
|
std::string eval_res_1;
|
||||||
|
{
|
||||||
|
std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
|
||||||
|
learner->SetParams(Args{{"eval_metric", "ndcg"}, {"eval_metric", "map"}});
|
||||||
|
learner->Configure();
|
||||||
|
learner->UpdateOneIter(0, p_fmat);
|
||||||
|
eval_res_0 = learner->EvalOneIter(0, {p_fmat}, {"Train"});
|
||||||
|
common::MemoryBufferStream fo(&serialised_model_tmp);
|
||||||
|
learner->Save(&fo);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
common::MemoryBufferStream fi(&serialised_model_tmp);
|
||||||
|
std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
|
||||||
|
learner->Load(&fi);
|
||||||
|
eval_res_1 = learner->EvalOneIter(0, {p_fmat}, {"Train"});
|
||||||
|
}
|
||||||
|
ASSERT_EQ(eval_res_0, eval_res_1);
|
||||||
|
}
|
||||||
|
|
||||||
// Crashes the test runner if there are race condiditions.
|
// Crashes the test runner if there are race condiditions.
|
||||||
//
|
//
|
||||||
// Build with additional cmake flags to enable thread sanitizer
|
// Build with additional cmake flags to enable thread sanitizer
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user