Fix metric serialization. (#9405)
This commit is contained in:
@@ -215,6 +215,34 @@ TEST(Learner, JsonModelIO) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Learner, ConfigIO) {
|
||||
bst_row_t n_samples = 128;
|
||||
bst_feature_t n_features = 12;
|
||||
std::shared_ptr<DMatrix> p_fmat{
|
||||
RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true, false, 2)};
|
||||
|
||||
auto serialised_model_tmp = std::string{};
|
||||
std::string eval_res_0;
|
||||
std::string eval_res_1;
|
||||
{
|
||||
std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
|
||||
learner->SetParams(Args{{"eval_metric", "ndcg"}, {"eval_metric", "map"}});
|
||||
learner->Configure();
|
||||
learner->UpdateOneIter(0, p_fmat);
|
||||
eval_res_0 = learner->EvalOneIter(0, {p_fmat}, {"Train"});
|
||||
common::MemoryBufferStream fo(&serialised_model_tmp);
|
||||
learner->Save(&fo);
|
||||
}
|
||||
|
||||
{
|
||||
common::MemoryBufferStream fi(&serialised_model_tmp);
|
||||
std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
|
||||
learner->Load(&fi);
|
||||
eval_res_1 = learner->EvalOneIter(0, {p_fmat}, {"Train"});
|
||||
}
|
||||
ASSERT_EQ(eval_res_0, eval_res_1);
|
||||
}
|
||||
|
||||
// Crashes the test runner if there are race condiditions.
|
||||
//
|
||||
// Build with additional cmake flags to enable thread sanitizer
|
||||
|
||||
Reference in New Issue
Block a user