From 0de7c474959c47003abb941dc7a657d8e010b96b Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 22 Jul 2023 08:39:21 +0800 Subject: [PATCH] Fix metric serialization. (#9405) --- src/learner.cc | 8 ++++---- tests/cpp/test_learner.cc | 28 ++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/learner.cc b/src/learner.cc index 2f453ea30..b2d6baff0 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -583,8 +583,9 @@ class LearnerConfiguration : public Learner { auto& objective_fn = learner_parameters["objective"]; obj_->SaveConfig(&objective_fn); - std::vector metrics(metrics_.size(), Json{Object{}}); + std::vector metrics(metrics_.size()); for (size_t i = 0; i < metrics_.size(); ++i) { + metrics[i] = Object{}; metrics_[i]->SaveConfig(&metrics[i]); } learner_parameters["metrics"] = Array(std::move(metrics)); @@ -807,14 +808,13 @@ class LearnerConfiguration : public Learner { void ConfigureMetrics(Args const& args) { for (auto const& name : metric_names_) { - auto DupCheck = [&name](std::unique_ptr const& m) { - return m->Name() != name; - }; + auto DupCheck = [&name](std::unique_ptr const& m) { return m->Name() != name; }; if (std::all_of(metrics_.begin(), metrics_.end(), DupCheck)) { metrics_.emplace_back(std::unique_ptr(Metric::Create(name, &ctx_))); mparam_.contain_eval_metrics = 1; } } + for (auto& p_metric : metrics_) { p_metric->Configure(args); } diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index 2165c6c8d..3615f7587 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -215,6 +215,34 @@ TEST(Learner, JsonModelIO) { } } +TEST(Learner, ConfigIO) { + bst_row_t n_samples = 128; + bst_feature_t n_features = 12; + std::shared_ptr p_fmat{ + RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true, false, 2)}; + + auto serialised_model_tmp = std::string{}; + std::string eval_res_0; + std::string eval_res_1; + { + std::unique_ptr learner{Learner::Create({p_fmat})}; + learner->SetParams(Args{{"eval_metric", "ndcg"}, {"eval_metric", "map"}}); + learner->Configure(); + learner->UpdateOneIter(0, p_fmat); + eval_res_0 = learner->EvalOneIter(0, {p_fmat}, {"Train"}); + common::MemoryBufferStream fo(&serialised_model_tmp); + learner->Save(&fo); + } + + { + common::MemoryBufferStream fi(&serialised_model_tmp); + std::unique_ptr learner{Learner::Create({p_fmat})}; + learner->Load(&fi); + eval_res_1 = learner->EvalOneIter(0, {p_fmat}, {"Train"}); + } + ASSERT_EQ(eval_res_0, eval_res_1); +} + // Crashes the test runner if there are race condiditions. // // Build with additional cmake flags to enable thread sanitizer