Pass shared pointer instead of raw pointer to Learner. (#5302)

Extracted from https://github.com/dmlc/xgboost/pull/5220 .
This commit is contained in:
Jiaming Yuan 2020-02-11 14:16:38 +08:00 committed by GitHub
parent 2e0067e790
commit 29eeea709a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 97 additions and 73 deletions

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2015-2019 by Contributors
* Copyright 2015-2020 by Contributors
* \file learner.h
* \brief Learner interface that integrates objective, gbm and evaluation together.
* This is the user facing XGBoost training module.
@ -59,7 +59,7 @@ class Learner : public Model, public Configurable, public rabit::Serializable {
* \param iter current iteration number
* \param train reference to the data matrix.
*/
virtual void UpdateOneIter(int iter, DMatrix* train) = 0;
virtual void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) = 0;
/*!
* \brief Do customized gradient boosting with in_gpair.
* in_gair can be mutated after this call.
@ -68,7 +68,7 @@ class Learner : public Model, public Configurable, public rabit::Serializable {
* \param in_gpair The input gradient statistics.
*/
virtual void BoostOneIter(int iter,
DMatrix* train,
std::shared_ptr<DMatrix> train,
HostDeviceVector<GradientPair>* in_gpair) = 0;
/*!
* \brief evaluate the model for specific iteration using the configured metrics.
@ -78,7 +78,7 @@ class Learner : public Model, public Configurable, public rabit::Serializable {
* \return a string corresponding to the evaluation result
*/
virtual std::string EvalOneIter(int iter,
const std::vector<DMatrix*>& data_sets,
const std::vector<std::shared_ptr<DMatrix>>& data_sets,
const std::vector<std::string>& data_names) = 0;
/*!
* \brief get prediction given the model.
@ -92,7 +92,7 @@ class Learner : public Model, public Configurable, public rabit::Serializable {
* \param approx_contribs whether to approximate the feature contributions for speed
* \param pred_interactions whether to compute the feature pair contributions
*/
virtual void Predict(DMatrix* data,
virtual void Predict(std::shared_ptr<DMatrix> data,
bool output_margin,
HostDeviceVector<bst_float> *out_preds,
unsigned ntree_limit = 0,

View File

@ -1,4 +1,4 @@
// Copyright (c) 2014-2019 by Contributors
// Copyright (c) 2014-2020 by Contributors
#include <dmlc/thread_local.h>
#include <rabit/rabit.h>
#include <rabit/c_api.h>
@ -498,7 +498,7 @@ XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle,
auto *dtr =
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
bst->UpdateOneIter(iter, dtr->get());
bst->UpdateOneIter(iter, *dtr);
API_END();
}
@ -519,7 +519,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
tmp_gpair_h[i] = GradientPair(grad[i], hess[i]);
}
bst->BoostOneIter(0, dtr->get(), &tmp_gpair);
bst->BoostOneIter(0, *dtr, &tmp_gpair);
API_END();
}
@ -533,11 +533,11 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Learner*>(handle);
std::vector<DMatrix*> data_sets;
std::vector<std::shared_ptr<DMatrix>> data_sets;
std::vector<std::string> data_names;
for (xgboost::bst_ulong i = 0; i < len; ++i) {
data_sets.push_back(static_cast<std::shared_ptr<DMatrix>*>(dmats[i])->get());
data_sets.push_back(*static_cast<std::shared_ptr<DMatrix>*>(dmats[i]));
data_names.emplace_back(evnames[i]);
}
@ -560,7 +560,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
auto *bst = static_cast<Learner*>(handle);
HostDeviceVector<bst_float> tmp_preds;
bst->Predict(
static_cast<std::shared_ptr<DMatrix>*>(dmat)->get(),
*static_cast<std::shared_ptr<DMatrix>*>(dmat),
(option_mask & 1) != 0,
&tmp_preds, ntree_limit,
static_cast<bool>(training),

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2014-2019 by Contributors
* Copyright 2014-2020 by Contributors
* \file cli_main.cc
* \brief The command line interface program of xgboost.
* This file is not included in dynamic library.
@ -165,7 +165,7 @@ void CLITrain(const CLIParam& param) {
param.dsplit == 2));
std::vector<std::shared_ptr<DMatrix> > deval;
std::vector<std::shared_ptr<DMatrix> > cache_mats;
std::vector<DMatrix*> eval_datasets;
std::vector<std::shared_ptr<DMatrix>> eval_datasets;
cache_mats.push_back(dtrain);
for (size_t i = 0; i < param.eval_data_names.size(); ++i) {
deval.emplace_back(
@ -173,12 +173,12 @@ void CLITrain(const CLIParam& param) {
param.eval_data_paths[i],
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
param.dsplit == 2)));
eval_datasets.push_back(deval.back().get());
eval_datasets.push_back(deval.back());
cache_mats.push_back(deval.back());
}
std::vector<std::string> eval_data_names = param.eval_data_names;
if (param.eval_train) {
eval_datasets.push_back(dtrain.get());
eval_datasets.push_back(dtrain);
eval_data_names.emplace_back("train");
}
// initialize the learner.
@ -203,7 +203,7 @@ void CLITrain(const CLIParam& param) {
double elapsed = dmlc::GetTime() - start;
if (version % 2 == 0) {
LOG(INFO) << "boosting round " << i << ", " << elapsed << " sec elapsed";
learner->UpdateOneIter(i, dtrain.get());
learner->UpdateOneIter(i, dtrain);
if (learner->AllowLazyCheckPoint()) {
rabit::LazyCheckPoint(learner.get());
} else {
@ -305,7 +305,7 @@ void CLIPredict(const CLIParam& param) {
CHECK_NE(param.test_path, "NULL")
<< "Test dataset parameter test:data must be specified.";
// load data
std::unique_ptr<DMatrix> dtest(
std::shared_ptr<DMatrix> dtest(
DMatrix::Load(
param.test_path,
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
@ -321,7 +321,7 @@ void CLIPredict(const CLIParam& param) {
LOG(INFO) << "start prediction...";
HostDeviceVector<bst_float> preds;
learner->Predict(dtest.get(), param.pred_margin, &preds, param.ntree_limit);
learner->Predict(dtest, param.pred_margin, &preds, param.ntree_limit);
LOG(CONSOLE) << "writing prediction to " << param.name_pred;
std::unique_ptr<dmlc::Stream> fo(

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2017-2019 XGBoost contributors
* Copyright 2017-2020 XGBoost contributors
*/
#pragma once
#include <thrust/device_ptr.h>
@ -9,7 +9,6 @@
#include <thrust/system_error.h>
#include <thrust/logical.h>
#include <omp.h>
#include <rabit/rabit.h>
#include <cub/cub.cuh>
#include <cub/util_allocator.cuh>

View File

@ -1,11 +1,12 @@
/*!
* Copyright 2019 XGBoost contributors
* Copyright 2019-2020 XGBoost contributors
* \file observer.h
*/
#ifndef XGBOOST_COMMON_OBSERVER_H_
#define XGBOOST_COMMON_OBSERVER_H_
#include <iostream>
#include <limits>
#include <string>
#include <vector>
@ -63,7 +64,8 @@ class TrainingObserver {
}
/*\brief Observe data hosted by `std::vector'. */
template <typename T>
void Observe(std::vector<T> const& h_vec, std::string name) const {
void Observe(std::vector<T> const& h_vec, std::string name,
size_t n = std::numeric_limits<std::size_t>::max()) const {
if (XGBOOST_EXPECT(!observe_, true)) { return; }
OBSERVER_PRINT << "Procedure: " << name << OBSERVER_ENDL;
@ -72,20 +74,25 @@ class TrainingObserver {
if (i % 8 == 0) {
OBSERVER_PRINT << OBSERVER_NEWLINE;
}
if ((i + 1) == n) {
break;
}
}
OBSERVER_PRINT << OBSERVER_ENDL;
}
/*\brief Observe data hosted by `HostDeviceVector'. */
template <typename T>
void Observe(HostDeviceVector<T> const& vec, std::string name) const {
void Observe(HostDeviceVector<T> const& vec, std::string name,
size_t n = std::numeric_limits<std::size_t>::max()) const {
if (XGBOOST_EXPECT(!observe_, true)) { return; }
auto const& h_vec = vec.HostVector();
this->Observe(h_vec, name);
this->Observe(h_vec, name, n);
}
template <typename T>
void Observe(HostDeviceVector<T>* vec, std::string name) const {
void Observe(HostDeviceVector<T>* vec, std::string name,
size_t n = std::numeric_limits<std::size_t>::max()) const {
if (XGBOOST_EXPECT(!observe_, true)) { return; }
this->Observe(*vec, name);
this->Observe(*vec, name, n);
}
/*\brief Observe objects with `XGBoostParamer' type. */

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2014 by Contributors
* Copyright 2014-2020 by Contributors
* \file gblinear.cc
* \brief Implementation of Linear booster, with L1/L2 regularization: Elastic Net
* the update rule is parallel coordinate descent (shotgun)
@ -239,7 +239,7 @@ class GBLinear : public GradientBooster {
void PredictBatchInternal(DMatrix *p_fmat,
std::vector<bst_float> *out_preds) {
monitor_.Start("PredictBatchInternal");
model_.LazyInitModel();
model_.LazyInitModel();
std::vector<bst_float> &preds = *out_preds;
const auto& base_margin = p_fmat->Info().base_margin_.ConstHostVector();
// start collecting the prediction
@ -250,6 +250,9 @@ class GBLinear : public GradientBooster {
// k is number of group
// parallel over local batch
const auto nsize = static_cast<omp_ulong>(batch.Size());
if (base_margin.size() != 0) {
CHECK_EQ(base_margin.size(), nsize * ngroup);
}
#pragma omp parallel for schedule(static)
for (omp_ulong i = 0; i < nsize; ++i) {
const size_t ridx = batch.base_rowid + i;

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2014-2019 by Contributors
* Copyright 2014-2020 by Contributors
* \file learner.cc
* \brief Implementation of learning algorithm.
* \author Tianqi Chen
@ -691,7 +691,7 @@ class LearnerImpl : public Learner {
return gbm_->DumpModel(fmap, with_stats, format);
}
void UpdateOneIter(int iter, DMatrix* train) override {
void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) override {
monitor_.Start("UpdateOneIter");
TrainingObserver::Instance().Update(iter);
this->Configure();
@ -699,23 +699,23 @@ class LearnerImpl : public Learner {
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
}
this->CheckDataSplitMode();
this->ValidateDMatrix(train);
this->ValidateDMatrix(train.get());
monitor_.Start("PredictRaw");
this->PredictRaw(train, &preds_[train], true);
this->PredictRaw(train.get(), &preds_[train.get()], true);
monitor_.Stop("PredictRaw");
TrainingObserver::Instance().Observe(preds_[train], "Predictions");
TrainingObserver::Instance().Observe(preds_[train.get()], "Predictions");
monitor_.Start("GetGradient");
obj_->GetGradient(preds_[train], train->Info(), iter, &gpair_);
obj_->GetGradient(preds_[train.get()], train->Info(), iter, &gpair_);
monitor_.Stop("GetGradient");
TrainingObserver::Instance().Observe(gpair_, "Gradients");
gbm_->DoBoost(train, &gpair_, obj_.get());
gbm_->DoBoost(train.get(), &gpair_, obj_.get());
monitor_.Stop("UpdateOneIter");
}
void BoostOneIter(int iter, DMatrix* train,
void BoostOneIter(int iter, std::shared_ptr<DMatrix> train,
HostDeviceVector<GradientPair>* in_gpair) override {
monitor_.Start("BoostOneIter");
this->Configure();
@ -723,13 +723,13 @@ class LearnerImpl : public Learner {
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
}
this->CheckDataSplitMode();
this->ValidateDMatrix(train);
this->ValidateDMatrix(train.get());
gbm_->DoBoost(train, in_gpair);
gbm_->DoBoost(train.get(), in_gpair);
monitor_.Stop("BoostOneIter");
}
std::string EvalOneIter(int iter, const std::vector<DMatrix*>& data_sets,
std::string EvalOneIter(int iter, const std::vector<std::shared_ptr<DMatrix>>& data_sets,
const std::vector<std::string>& data_names) override {
monitor_.Start("EvalOneIter");
this->Configure();
@ -741,9 +741,9 @@ class LearnerImpl : public Learner {
metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
}
for (size_t i = 0; i < data_sets.size(); ++i) {
DMatrix * dmat = data_sets[i];
DMatrix * dmat = data_sets[i].get();
this->ValidateDMatrix(dmat);
this->PredictRaw(data_sets[i], &preds_[dmat], false);
this->PredictRaw(dmat, &preds_[dmat], false);
obj_->EvalTransform(&preds_[dmat]);
for (auto& ev : metrics_) {
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
@ -805,7 +805,7 @@ class LearnerImpl : public Learner {
return generic_parameters_;
}
void Predict(DMatrix* data, bool output_margin,
void Predict(std::shared_ptr<DMatrix> data, bool output_margin,
HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit,
bool training,
bool pred_leaf, bool pred_contribs, bool approx_contribs,
@ -816,14 +816,14 @@ class LearnerImpl : public Learner {
this->Configure();
CHECK_LE(multiple_predictions, 1) << "Perform one kind of prediction at a time.";
if (pred_contribs) {
gbm_->PredictContribution(data, &out_preds->HostVector(), ntree_limit, approx_contribs);
gbm_->PredictContribution(data.get(), &out_preds->HostVector(), ntree_limit, approx_contribs);
} else if (pred_interactions) {
gbm_->PredictInteractionContributions(data, &out_preds->HostVector(), ntree_limit,
gbm_->PredictInteractionContributions(data.get(), &out_preds->HostVector(), ntree_limit,
approx_contribs);
} else if (pred_leaf) {
gbm_->PredictLeaf(data, &out_preds->HostVector(), ntree_limit);
gbm_->PredictLeaf(data.get(), &out_preds->HostVector(), ntree_limit);
} else {
this->PredictRaw(data, out_preds, training, ntree_limit);
this->PredictRaw(data.get(), out_preds, training, ntree_limit);
if (!output_margin) {
obj_->PredTransform(out_preds);
}

View File

@ -1,4 +1,6 @@
// Copyright (c) 2019 by Contributors
/*!
* Copyright 2019-2020 XGBoost contributors
*/
#include <gtest/gtest.h>
#include <xgboost/version_config.h>
#include <xgboost/c_api.h>
@ -92,7 +94,7 @@ TEST(c_api, ConfigIO) {
std::shared_ptr<Learner> learner { Learner::Create(mat) };
BoosterHandle handle = learner.get();
learner->UpdateOneIter(0, p_dmat.get());
learner->UpdateOneIter(0, p_dmat);
char const* out[1];
bst_ulong len {0};
@ -127,7 +129,7 @@ TEST(c_api, JsonModelIO) {
std::shared_ptr<Learner> learner { Learner::Create(mat) };
learner->UpdateOneIter(0, p_dmat.get());
learner->UpdateOneIter(0, p_dmat);
BoosterHandle handle = learner.get();
std::string modelfile_0 = tempdir.path + "/model_0.json";

View File

@ -1,3 +1,6 @@
/*!
* Copyright 2019-2020 XGBoost contributors
*/
#include <gtest/gtest.h>
#include <dmlc/filesystem.h>
#include <xgboost/generic_parameters.h>
@ -62,7 +65,7 @@ TEST(GBTree, ChoosePredictor) {
auto learner = std::unique_ptr<Learner>(Learner::Create({p_dmat}));
learner->SetParams(Args{{"tree_method", "gpu_hist"}, {"gpu_id", "0"}});
for (size_t i = 0; i < 4; ++i) {
learner->UpdateOneIter(i, p_dmat.get());
learner->UpdateOneIter(i, p_dmat);
}
ASSERT_TRUE(data.HostCanWrite());
dmlc::TemporaryDirectory tempdir;
@ -81,7 +84,7 @@ TEST(GBTree, ChoosePredictor) {
}
learner->SetParams(Args{{"tree_method", "gpu_hist"}, {"gpu_id", "0"}});
for (size_t i = 0; i < 4; ++i) {
learner->UpdateOneIter(i, p_dmat.get());
learner->UpdateOneIter(i, p_dmat);
}
ASSERT_TRUE(data.HostCanWrite());
@ -94,7 +97,7 @@ TEST(GBTree, ChoosePredictor) {
learner = std::unique_ptr<Learner>(Learner::Create({p_dmat}));
learner->SetParams(Args{{"tree_method", "gpu_hist"}, {"gpu_id", "0"}});
for (size_t i = 0; i < 4; ++i) {
learner->UpdateOneIter(i, p_dmat.get());
learner->UpdateOneIter(i, p_dmat);
}
// data is not pulled back into host
ASSERT_FALSE(data.HostCanWrite());
@ -196,13 +199,13 @@ TEST(Dart, Prediction) {
learner->Configure();
for (size_t i = 0; i < 16; ++i) {
learner->UpdateOneIter(i, p_mat.get());
learner->UpdateOneIter(i, p_mat);
}
HostDeviceVector<float> predts_training;
learner->Predict(p_mat.get(), false, &predts_training, 0, true);
learner->Predict(p_mat, false, &predts_training, 0, true);
HostDeviceVector<float> predts_inference;
learner->Predict(p_mat.get(), false, &predts_inference, 0, false);
learner->Predict(p_mat, false, &predts_inference, 0, false);
auto& h_predts_training = predts_training.ConstHostVector();
auto& h_predts_inference = predts_inference.ConstHostVector();

View File

@ -1,4 +1,6 @@
// Copyright by Contributors
/*!
* Copyright 2017-2020 XGBoost contributors
*/
#include <gtest/gtest.h>
#include <vector>
#include "helpers.h"
@ -79,13 +81,13 @@ TEST(Learner, CheckGroup) {
std::vector<std::shared_ptr<xgboost::DMatrix>> mat = {p_mat};
auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
learner->SetParams({Arg{"objective", "rank:pairwise"}});
EXPECT_NO_THROW(learner->UpdateOneIter(0, p_mat.get()));
EXPECT_NO_THROW(learner->UpdateOneIter(0, p_mat));
group.resize(kNumGroups+1);
group[3] = 4;
group[4] = 1;
p_mat->Info().SetInfo("group", group.data(), DataType::kUInt32, kNumGroups+1);
EXPECT_ANY_THROW(learner->UpdateOneIter(0, p_mat.get()));
EXPECT_ANY_THROW(learner->UpdateOneIter(0, p_mat));
delete pp_mat;
}
@ -107,7 +109,7 @@ TEST(Learner, SLOW_CheckMultiBatch) {
std::vector<std::shared_ptr<DMatrix>> mat{dmat};
auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
learner->SetParams(Args{{"objective", "binary:logistic"}});
learner->UpdateOneIter(0, dmat.get());
learner->UpdateOneIter(0, dmat);
}
TEST(Learner, Configuration) {
@ -142,6 +144,7 @@ TEST(Learner, JsonModelIO) {
auto pp_dmat = CreateDMatrix(kRows, 10, 0);
std::shared_ptr<DMatrix> p_dmat {*pp_dmat};
p_dmat->Info().labels_.Resize(kRows);
CHECK_NE(p_dmat->Info().num_col_, 0);
{
std::unique_ptr<Learner> learner { Learner::Create({p_dmat}) };
@ -160,7 +163,7 @@ TEST(Learner, JsonModelIO) {
{
std::unique_ptr<Learner> learner { Learner::Create({p_dmat}) };
for (int32_t iter = 0; iter < kIters; ++iter) {
learner->UpdateOneIter(iter, p_dmat.get());
learner->UpdateOneIter(iter, p_dmat);
}
learner->SetAttr("best_score", "15.2");
@ -197,20 +200,20 @@ TEST(Learner, GPUConfiguration) {
std::unique_ptr<Learner> learner {Learner::Create(mat)};
learner->SetParams({Arg{"booster", "gblinear"},
Arg{"updater", "gpu_coord_descent"}});
learner->UpdateOneIter(0, p_dmat.get());
learner->UpdateOneIter(0, p_dmat);
ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
}
{
std::unique_ptr<Learner> learner {Learner::Create(mat)};
learner->SetParams({Arg{"tree_method", "gpu_hist"}});
learner->UpdateOneIter(0, p_dmat.get());
learner->UpdateOneIter(0, p_dmat);
ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
}
{
// with CPU algorithm
std::unique_ptr<Learner> learner {Learner::Create(mat)};
learner->SetParams({Arg{"tree_method", "hist"}});
learner->UpdateOneIter(0, p_dmat.get());
learner->UpdateOneIter(0, p_dmat);
ASSERT_EQ(learner->GetGenericParameter().gpu_id, -1);
}
{
@ -218,7 +221,7 @@ TEST(Learner, GPUConfiguration) {
std::unique_ptr<Learner> learner {Learner::Create(mat)};
learner->SetParams({Arg{"tree_method", "hist"},
Arg{"gpu_id", "0"}});
learner->UpdateOneIter(0, p_dmat.get());
learner->UpdateOneIter(0, p_dmat);
ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
}
{
@ -228,7 +231,7 @@ TEST(Learner, GPUConfiguration) {
std::unique_ptr<Learner> learner {Learner::Create(mat)};
learner->SetParams({Arg{"tree_method", "hist"},
Arg{"predictor", "gpu_predictor"}});
learner->UpdateOneIter(0, p_dmat.get());
learner->UpdateOneIter(0, p_dmat);
ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
}

View File

@ -1,3 +1,4 @@
// Copyright (c) 2019-2020 by Contributors
#include <gtest/gtest.h>
#include <dmlc/filesystem.h>
#include <string>
@ -24,12 +25,13 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
std::vector<std::string> dumped_0;
std::string model_at_kiter;
// Train for kIters.
{
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
std::unique_ptr<Learner> learner {Learner::Create({p_dmat})};
learner->SetParams(args);
for (int32_t iter = 0; iter < kIters; ++iter) {
learner->UpdateOneIter(iter, p_dmat.get());
learner->UpdateOneIter(iter, p_dmat);
}
dumped_0 = learner->DumpModel(fmap, true, "json");
learner->Save(fo.get());
@ -38,6 +40,7 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
learner->Save(&mem_out);
}
// Assert dumped model is same after loading
std::vector<std::string> dumped_1;
{
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r"));
@ -73,7 +76,7 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
}
for (int32_t iter = kIters; iter < 2 * kIters; ++iter) {
learner->UpdateOneIter(iter, p_dmat.get());
learner->UpdateOneIter(iter, p_dmat);
}
common::MemoryBufferStream fo(&continued_model);
learner->Save(&fo);
@ -84,7 +87,7 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
std::unique_ptr<Learner> learner{Learner::Create({p_dmat})};
learner->SetParams(args);
for (int32_t iter = 0; iter < 2 * kIters; ++iter) {
learner->UpdateOneIter(iter, p_dmat.get());
learner->UpdateOneIter(iter, p_dmat);
// Verify model is same at the same iteration during two training
// sessions.
@ -98,6 +101,7 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
common::MemoryBufferStream fo(&model_at_2kiter);
learner->Save(&fo);
}
Json m_0 = Json::Load(StringView{continued_model.c_str(), continued_model.size()});
Json m_1 = Json::Load(StringView{model_at_2kiter.c_str(), model_at_2kiter.size()});
ASSERT_EQ(m_0, m_1);
@ -127,7 +131,7 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
}
for (int32_t iter = kIters; iter < 2 * kIters; ++iter) {
learner->UpdateOneIter(iter, p_dmat.get());
learner->UpdateOneIter(iter, p_dmat);
}
serialised_model_tmp = std::string{};
common::MemoryBufferStream fo(&serialised_model_tmp);
@ -306,7 +310,7 @@ TEST_F(SerializationTest, ConfigurationCount) {
learner->SetParam("enable_experimental_json_serialization", "1");
for (size_t i = 0; i < 10; ++i) {
learner->UpdateOneIter(i, p_dmat.get());
learner->UpdateOneIter(i, p_dmat);
}
common::MemoryBufferStream fo(&model_str);
learner->Save(&fo);
@ -317,7 +321,7 @@ TEST_F(SerializationTest, ConfigurationCount) {
auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
learner->Load(&fi);
for (size_t i = 0; i < 10; ++i) {
learner->UpdateOneIter(i, p_dmat.get());
learner->UpdateOneIter(i, p_dmat);
}
}

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2017-2019 XGBoost contributors
* Copyright 2017-2020 XGBoost contributors
*/
#include <thrust/device_vector.h>
#include <dmlc/filesystem.h>
@ -387,6 +387,7 @@ TEST(GpuHist, UniformSampling) {
constexpr size_t kRows = 4096;
constexpr size_t kCols = 2;
constexpr float kSubsample = 0.99;
common::GlobalRandom().seed(1994);
// Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
@ -415,6 +416,7 @@ TEST(GpuHist, GradientBasedSampling) {
constexpr size_t kRows = 4096;
constexpr size_t kCols = 2;
constexpr float kSubsample = 0.99;
common::GlobalRandom().seed(1994);
// Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
@ -478,6 +480,7 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
constexpr size_t kPageSize = 1024;
constexpr float kSubsample = 0.5;
const std::string kSamplingMethod = "gradient_based";
common::GlobalRandom().seed(0);
// Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
@ -503,7 +506,7 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
auto preds_h = preds.ConstHostVector();
auto preds_ext_h = preds_ext.ConstHostVector();
for (int i = 0; i < kRows; i++) {
EXPECT_NEAR(preds_h[i], preds_ext_h[i], 3e-3);
EXPECT_NEAR(preds_h[i], preds_ext_h[i], 2e-3);
}
}