Pass shared pointer instead of raw pointer to Learner. (#5302)

Extracted from https://github.com/dmlc/xgboost/pull/5220 .
This commit is contained in:
Jiaming Yuan
2020-02-11 14:16:38 +08:00
committed by GitHub
parent 2e0067e790
commit 29eeea709a
12 changed files with 97 additions and 73 deletions

View File

@@ -1,3 +1,6 @@
/*!
* Copyright 2019-2020 XGBoost contributors
*/
#include <gtest/gtest.h>
#include <dmlc/filesystem.h>
#include <xgboost/generic_parameters.h>
@@ -62,7 +65,7 @@ TEST(GBTree, ChoosePredictor) {
auto learner = std::unique_ptr<Learner>(Learner::Create({p_dmat}));
learner->SetParams(Args{{"tree_method", "gpu_hist"}, {"gpu_id", "0"}});
for (size_t i = 0; i < 4; ++i) {
learner->UpdateOneIter(i, p_dmat.get());
learner->UpdateOneIter(i, p_dmat);
}
ASSERT_TRUE(data.HostCanWrite());
dmlc::TemporaryDirectory tempdir;
@@ -81,7 +84,7 @@ TEST(GBTree, ChoosePredictor) {
}
learner->SetParams(Args{{"tree_method", "gpu_hist"}, {"gpu_id", "0"}});
for (size_t i = 0; i < 4; ++i) {
learner->UpdateOneIter(i, p_dmat.get());
learner->UpdateOneIter(i, p_dmat);
}
ASSERT_TRUE(data.HostCanWrite());
@@ -94,7 +97,7 @@ TEST(GBTree, ChoosePredictor) {
learner = std::unique_ptr<Learner>(Learner::Create({p_dmat}));
learner->SetParams(Args{{"tree_method", "gpu_hist"}, {"gpu_id", "0"}});
for (size_t i = 0; i < 4; ++i) {
learner->UpdateOneIter(i, p_dmat.get());
learner->UpdateOneIter(i, p_dmat);
}
// data is not pulled back into host
ASSERT_FALSE(data.HostCanWrite());
@@ -196,13 +199,13 @@ TEST(Dart, Prediction) {
learner->Configure();
for (size_t i = 0; i < 16; ++i) {
learner->UpdateOneIter(i, p_mat.get());
learner->UpdateOneIter(i, p_mat);
}
HostDeviceVector<float> predts_training;
learner->Predict(p_mat.get(), false, &predts_training, 0, true);
learner->Predict(p_mat, false, &predts_training, 0, true);
HostDeviceVector<float> predts_inference;
learner->Predict(p_mat.get(), false, &predts_inference, 0, false);
learner->Predict(p_mat, false, &predts_inference, 0, false);
auto& h_predts_training = predts_training.ConstHostVector();
auto& h_predts_inference = predts_inference.ConstHostVector();