Pass shared pointer instead of raw pointer to Learner. (#5302)

Extracted from https://github.com/dmlc/xgboost/pull/5220 .
This commit is contained in:
Jiaming Yuan
2020-02-11 14:16:38 +08:00
committed by GitHub
parent 2e0067e790
commit 29eeea709a
12 changed files with 97 additions and 73 deletions

View File

@@ -1,3 +1,4 @@
// Copyright (c) 2019-2020 by Contributors
#include <gtest/gtest.h>
#include <dmlc/filesystem.h>
#include <string>
@@ -24,12 +25,13 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
std::vector<std::string> dumped_0;
std::string model_at_kiter;
// Train for kIters.
{
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
std::unique_ptr<Learner> learner {Learner::Create({p_dmat})};
learner->SetParams(args);
for (int32_t iter = 0; iter < kIters; ++iter) {
learner->UpdateOneIter(iter, p_dmat.get());
learner->UpdateOneIter(iter, p_dmat);
}
dumped_0 = learner->DumpModel(fmap, true, "json");
learner->Save(fo.get());
@@ -38,6 +40,7 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
learner->Save(&mem_out);
}
// Assert dumped model is same after loading
std::vector<std::string> dumped_1;
{
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r"));
@@ -73,7 +76,7 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
}
for (int32_t iter = kIters; iter < 2 * kIters; ++iter) {
learner->UpdateOneIter(iter, p_dmat.get());
learner->UpdateOneIter(iter, p_dmat);
}
common::MemoryBufferStream fo(&continued_model);
learner->Save(&fo);
@@ -84,7 +87,7 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
std::unique_ptr<Learner> learner{Learner::Create({p_dmat})};
learner->SetParams(args);
for (int32_t iter = 0; iter < 2 * kIters; ++iter) {
learner->UpdateOneIter(iter, p_dmat.get());
learner->UpdateOneIter(iter, p_dmat);
// Verify model is same at the same iteration during two training
// sessions.
@@ -98,6 +101,7 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
common::MemoryBufferStream fo(&model_at_2kiter);
learner->Save(&fo);
}
Json m_0 = Json::Load(StringView{continued_model.c_str(), continued_model.size()});
Json m_1 = Json::Load(StringView{model_at_2kiter.c_str(), model_at_2kiter.size()});
ASSERT_EQ(m_0, m_1);
@@ -127,7 +131,7 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
}
for (int32_t iter = kIters; iter < 2 * kIters; ++iter) {
learner->UpdateOneIter(iter, p_dmat.get());
learner->UpdateOneIter(iter, p_dmat);
}
serialised_model_tmp = std::string{};
common::MemoryBufferStream fo(&serialised_model_tmp);
@@ -306,7 +310,7 @@ TEST_F(SerializationTest, ConfigurationCount) {
learner->SetParam("enable_experimental_json_serialization", "1");
for (size_t i = 0; i < 10; ++i) {
learner->UpdateOneIter(i, p_dmat.get());
learner->UpdateOneIter(i, p_dmat);
}
common::MemoryBufferStream fo(&model_str);
learner->Save(&fo);
@@ -317,7 +321,7 @@ TEST_F(SerializationTest, ConfigurationCount) {
auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
learner->Load(&fi);
for (size_t i = 0; i < 10; ++i) {
learner->UpdateOneIter(i, p_dmat.get());
learner->UpdateOneIter(i, p_dmat);
}
}