diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h
index 1f78382c7..7dd0a2f2d 100644
--- a/include/xgboost/learner.h
+++ b/include/xgboost/learner.h
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2015-2019 by Contributors
+ * Copyright 2015-2020 by Contributors
 * \file learner.h
 * \brief Learner interface that integrates objective, gbm and evaluation together.
 *  This is the user facing XGBoost training module.
@@ -59,7 +59,7 @@ class Learner : public Model, public Configurable, public rabit::Serializable {
   * \param iter current iteration number
   * \param train reference to the data matrix.
   */
-  virtual void UpdateOneIter(int iter, DMatrix* train) = 0;
+  virtual void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) = 0;
  /*!
   * \brief Do customized gradient boosting with in_gpair.
   *  in_gair can be mutated after this call.
@@ -68,7 +68,7 @@ class Learner : public Model, public Configurable, public rabit::Serializable {
   * \param in_gpair The input gradient statistics.
   */
  virtual void BoostOneIter(int iter,
-                           DMatrix* train,
+                           std::shared_ptr<DMatrix> train,
                            HostDeviceVector<GradientPair>* in_gpair) = 0;
  /*!
   * \brief evaluate the model for specific iteration using the configured metrics.
@@ -78,7 +78,7 @@ class Learner : public Model, public Configurable, public rabit::Serializable {
   * \return a string corresponding to the evaluation result
   */
  virtual std::string EvalOneIter(int iter,
-                                 const std::vector<DMatrix*>& data_sets,
+                                 const std::vector<std::shared_ptr<DMatrix>>& data_sets,
                                  const std::vector<std::string>& data_names) = 0;
  /*!
   * \brief get prediction given the model.
@@ -92,7 +92,7 @@ class Learner : public Model, public Configurable, public rabit::Serializable {
   * \param approx_contribs whether to approximate the feature contributions for speed
   * \param pred_interactions whether to compute the feature pair contributions
   */
-  virtual void Predict(DMatrix* data,
+  virtual void Predict(std::shared_ptr<DMatrix> data,
                       bool output_margin,
                       HostDeviceVector<bst_float> *out_preds,
                       unsigned ntree_limit = 0,
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 37cbd6dd2..ccc49d91a 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2014-2019 by Contributors
+// Copyright (c) 2014-2020 by Contributors
 #include
 #include
 #include
@@ -498,7 +498,7 @@ XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle,
   auto *dtr =
       static_cast<std::shared_ptr<DMatrix>*>(dtrain);

-  bst->UpdateOneIter(iter, dtr->get());
+  bst->UpdateOneIter(iter, *dtr);
   API_END();
 }
@@ -519,7 +519,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
     tmp_gpair_h[i] = GradientPair(grad[i], hess[i]);
   }

-  bst->BoostOneIter(0, dtr->get(), &tmp_gpair);
+  bst->BoostOneIter(0, *dtr, &tmp_gpair);
   API_END();
 }
@@ -533,11 +533,11 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
   API_BEGIN();
   CHECK_HANDLE();
   auto* bst = static_cast<Learner*>(handle);
-  std::vector<DMatrix*> data_sets;
+  std::vector<std::shared_ptr<DMatrix>> data_sets;
   std::vector<std::string> data_names;

   for (xgboost::bst_ulong i = 0; i < len; ++i) {
-    data_sets.push_back(static_cast<std::shared_ptr<DMatrix>*>(dmats[i])->get());
+    data_sets.push_back(*static_cast<std::shared_ptr<DMatrix>*>(dmats[i]));
     data_names.emplace_back(evnames[i]);
   }
@@ -560,7 +560,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
   auto *bst = static_cast<Learner*>(handle);
   HostDeviceVector<bst_float> tmp_preds;
   bst->Predict(
-      static_cast<std::shared_ptr<DMatrix>*>(dmat)->get(),
+      *static_cast<std::shared_ptr<DMatrix>*>(dmat),
      (option_mask & 1) != 0,
      &tmp_preds, ntree_limit,
      static_cast<bool>(training),
diff --git a/src/cli_main.cc b/src/cli_main.cc
index 482e4f1b0..cda2f34a7 100644
--- a/src/cli_main.cc
+++ b/src/cli_main.cc
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014-2019 by Contributors
+ * Copyright 2014-2020 by Contributors
 * \file cli_main.cc
 * \brief The command line interface program of xgboost.
 *  This file is not included in dynamic library.
@@ -165,7 +165,7 @@ void CLITrain(const CLIParam& param) {
           param.dsplit == 2));
   std::vector<std::shared_ptr<DMatrix> > deval;
   std::vector<std::shared_ptr<DMatrix> > cache_mats;
-  std::vector<DMatrix*> eval_datasets;
+  std::vector<std::shared_ptr<DMatrix>> eval_datasets;
   cache_mats.push_back(dtrain);
   for (size_t i = 0; i < param.eval_data_names.size(); ++i) {
     deval.emplace_back(
@@ -173,12 +173,12 @@ void CLITrain(const CLIParam& param) {
             param.eval_data_paths[i],
             ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
             param.dsplit == 2)));
-    eval_datasets.push_back(deval.back().get());
+    eval_datasets.push_back(deval.back());
     cache_mats.push_back(deval.back());
   }
   std::vector<std::string> eval_data_names = param.eval_data_names;
   if (param.eval_train) {
-    eval_datasets.push_back(dtrain.get());
+    eval_datasets.push_back(dtrain);
     eval_data_names.emplace_back("train");
   }
   // initialize the learner.
@@ -203,7 +203,7 @@ void CLITrain(const CLIParam& param) {
     double elapsed = dmlc::GetTime() - start;
     if (version % 2 == 0) {
       LOG(INFO) << "boosting round " << i << ", " << elapsed << " sec elapsed";
-      learner->UpdateOneIter(i, dtrain.get());
+      learner->UpdateOneIter(i, dtrain);
       if (learner->AllowLazyCheckPoint()) {
         rabit::LazyCheckPoint(learner.get());
       } else {
@@ -305,7 +305,7 @@ void CLIPredict(const CLIParam& param) {
   CHECK_NE(param.test_path, "NULL")
       << "Test dataset parameter test:data must be specified.";
   // load data
-  std::unique_ptr<DMatrix> dtest(
+  std::shared_ptr<DMatrix> dtest(
       DMatrix::Load(
           param.test_path,
           ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
@@ -321,7 +321,7 @@ void CLIPredict(const CLIParam& param) {
   LOG(INFO) << "start prediction...";
   HostDeviceVector<bst_float> preds;
-  learner->Predict(dtest.get(), param.pred_margin, &preds, param.ntree_limit);
+  learner->Predict(dtest, param.pred_margin, &preds, param.ntree_limit);
   LOG(CONSOLE) << "writing prediction to " << param.name_pred;
   std::unique_ptr<dmlc::Stream> fo(
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index ad5d961c0..382505835 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2017-2019 XGBoost contributors
+ * Copyright 2017-2020 XGBoost contributors
 */
 #pragma once
 #include
@@ -9,7 +9,6 @@
 #include
 #include
-#include
 #include
 #include
 #include
diff --git a/src/common/observer.h b/src/common/observer.h
index f150fa880..c047cc79b 100644
--- a/src/common/observer.h
+++ b/src/common/observer.h
@@ -1,11 +1,12 @@
 /*!
- * Copyright 2019 XGBoost contributors
+ * Copyright 2019-2020 XGBoost contributors
 * \file observer.h
 */
 #ifndef XGBOOST_COMMON_OBSERVER_H_
 #define XGBOOST_COMMON_OBSERVER_H_

 #include
+#include <limits>
 #include
 #include
@@ -63,7 +64,8 @@ class TrainingObserver {
   }
   /*\brief Observe data hosted by `std::vector'. */
   template <typename T>
-  void Observe(std::vector<T> const& h_vec, std::string name) const {
+  void Observe(std::vector<T> const& h_vec, std::string name,
+               size_t n = std::numeric_limits<size_t>::max()) const {
     if (XGBOOST_EXPECT(!observe_, true)) { return; }
     OBSERVER_PRINT << "Procedure: " << name << OBSERVER_ENDL;

@@ -72,20 +74,25 @@ class TrainingObserver {
       if (i % 8 == 0) {
         OBSERVER_PRINT << OBSERVER_NEWLINE;
       }
+      if ((i + 1) == n) {
+        break;
+      }
     }
     OBSERVER_PRINT << OBSERVER_ENDL;
   }
   /*\brief Observe data hosted by `HostDeviceVector'. */
   template <typename T>
-  void Observe(HostDeviceVector<T> const& vec, std::string name) const {
+  void Observe(HostDeviceVector<T> const& vec, std::string name,
+               size_t n = std::numeric_limits<size_t>::max()) const {
     if (XGBOOST_EXPECT(!observe_, true)) { return; }
     auto const& h_vec = vec.HostVector();
-    this->Observe(h_vec, name);
+    this->Observe(h_vec, name, n);
   }

   template <typename T>
-  void Observe(HostDeviceVector<T>* vec, std::string name) const {
+  void Observe(HostDeviceVector<T>* vec, std::string name,
+               size_t n = std::numeric_limits<size_t>::max()) const {
     if (XGBOOST_EXPECT(!observe_, true)) { return; }
-    this->Observe(*vec, name);
+    this->Observe(*vec, name, n);
   }

   /*\brief Observe objects with `XGBoostParamer' type. */
diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc
index db45dae42..9724e37b5 100644
--- a/src/gbm/gblinear.cc
+++ b/src/gbm/gblinear.cc
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014 by Contributors
+ * Copyright 2014-2020 by Contributors
 * \file gblinear.cc
 * \brief Implementation of Linear booster, with L1/L2 regularization: Elastic Net
 *   the update rule is parallel coordinate descent (shotgun)
@@ -239,7 +239,7 @@ class GBLinear : public GradientBooster {
   void PredictBatchInternal(DMatrix *p_fmat,
                             std::vector<bst_float> *out_preds) {
     monitor_.Start("PredictBatchInternal");
-    model_.LazyInitModel();
+    model_.LazyInitModel();
     std::vector<bst_float> &preds = *out_preds;
     const auto& base_margin = p_fmat->Info().base_margin_.ConstHostVector();
     // start collecting the prediction
@@ -250,6 +250,9 @@ class GBLinear : public GradientBooster {
       // k is number of group
       // parallel over local batch
       const auto nsize = static_cast<omp_ulong>(batch.Size());
+      if (base_margin.size() != 0) {
+        CHECK_EQ(base_margin.size(), nsize * ngroup);
+      }
 #pragma omp parallel for schedule(static)
       for (omp_ulong i = 0; i < nsize; ++i) {
         const size_t ridx = batch.base_rowid + i;
diff --git a/src/learner.cc b/src/learner.cc
index 70ffceda6..10b7882c6 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014-2019 by Contributors
+ * Copyright 2014-2020 by Contributors
 * \file learner.cc
 * \brief Implementation of learning algorithm.
 * \author Tianqi Chen
@@ -691,7 +691,7 @@ class LearnerImpl : public Learner {
     return gbm_->DumpModel(fmap, with_stats, format);
   }

-  void UpdateOneIter(int iter, DMatrix* train) override {
+  void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) override {
     monitor_.Start("UpdateOneIter");
     TrainingObserver::Instance().Update(iter);
     this->Configure();
@@ -699,23 +699,23 @@ class LearnerImpl : public Learner {
      common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
     }
     this->CheckDataSplitMode();
-    this->ValidateDMatrix(train);
+    this->ValidateDMatrix(train.get());

     monitor_.Start("PredictRaw");
-    this->PredictRaw(train, &preds_[train], true);
+    this->PredictRaw(train.get(), &preds_[train.get()], true);
     monitor_.Stop("PredictRaw");
-    TrainingObserver::Instance().Observe(preds_[train], "Predictions");
+    TrainingObserver::Instance().Observe(preds_[train.get()], "Predictions");

     monitor_.Start("GetGradient");
-    obj_->GetGradient(preds_[train], train->Info(), iter, &gpair_);
+    obj_->GetGradient(preds_[train.get()], train->Info(), iter, &gpair_);
     monitor_.Stop("GetGradient");
     TrainingObserver::Instance().Observe(gpair_, "Gradients");

-    gbm_->DoBoost(train, &gpair_, obj_.get());
+    gbm_->DoBoost(train.get(), &gpair_, obj_.get());
     monitor_.Stop("UpdateOneIter");
   }

-  void BoostOneIter(int iter, DMatrix* train,
+  void BoostOneIter(int iter, std::shared_ptr<DMatrix> train,
                     HostDeviceVector<GradientPair>* in_gpair) override {
     monitor_.Start("BoostOneIter");
     this->Configure();
@@ -723,13 +723,13 @@ class LearnerImpl : public Learner {
      common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
     }
     this->CheckDataSplitMode();
-    this->ValidateDMatrix(train);
+    this->ValidateDMatrix(train.get());

-    gbm_->DoBoost(train, in_gpair);
+    gbm_->DoBoost(train.get(), in_gpair);
     monitor_.Stop("BoostOneIter");
   }

-  std::string EvalOneIter(int iter, const std::vector<DMatrix*>& data_sets,
+  std::string EvalOneIter(int iter, const std::vector<std::shared_ptr<DMatrix>>& data_sets,
                           const std::vector<std::string>& data_names) override {
     monitor_.Start("EvalOneIter");
     this->Configure();
@@ -741,9 +741,9 @@ class LearnerImpl : public Learner {
       metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
     }
     for (size_t i = 0; i < data_sets.size(); ++i) {
-      DMatrix * dmat = data_sets[i];
+      DMatrix * dmat = data_sets[i].get();
       this->ValidateDMatrix(dmat);
-      this->PredictRaw(data_sets[i], &preds_[dmat], false);
+      this->PredictRaw(dmat, &preds_[dmat], false);
       obj_->EvalTransform(&preds_[dmat]);
       for (auto& ev : metrics_) {
         os << '\t' << data_names[i] << '-' << ev->Name() << ':'
@@ -805,7 +805,7 @@ class LearnerImpl : public Learner {
     return generic_parameters_;
   }

-  void Predict(DMatrix* data, bool output_margin,
+  void Predict(std::shared_ptr<DMatrix> data, bool output_margin,
               HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit,
               bool training,
               bool pred_leaf, bool pred_contribs, bool approx_contribs,
               bool pred_interactions) override {
@@ -816,14 +816,14 @@ class LearnerImpl : public Learner {
     this->Configure();
     CHECK_LE(multiple_predictions, 1) << "Perform one kind of prediction at a time.";
     if (pred_contribs) {
-      gbm_->PredictContribution(data, &out_preds->HostVector(), ntree_limit, approx_contribs);
+      gbm_->PredictContribution(data.get(), &out_preds->HostVector(), ntree_limit, approx_contribs);
     } else if (pred_interactions) {
-      gbm_->PredictInteractionContributions(data, &out_preds->HostVector(), ntree_limit,
+      gbm_->PredictInteractionContributions(data.get(), &out_preds->HostVector(), ntree_limit,
                                             approx_contribs);
     } else if (pred_leaf) {
-      gbm_->PredictLeaf(data, &out_preds->HostVector(), ntree_limit);
+      gbm_->PredictLeaf(data.get(), &out_preds->HostVector(), ntree_limit);
     } else {
-      this->PredictRaw(data, out_preds, training, ntree_limit);
+      this->PredictRaw(data.get(), out_preds, training, ntree_limit);
       if (!output_margin) {
         obj_->PredTransform(out_preds);
       }
diff --git a/tests/cpp/c_api/test_c_api.cc b/tests/cpp/c_api/test_c_api.cc
index 02c0e6126..0385d819d 100644
--- a/tests/cpp/c_api/test_c_api.cc
+++ b/tests/cpp/c_api/test_c_api.cc
@@ -1,4 +1,6 @@
-// Copyright (c) 2019 by Contributors
+/*!
+ * Copyright 2019-2020 XGBoost contributors
+ */
 #include
 #include
 #include
@@ -92,7 +94,7 @@ TEST(c_api, ConfigIO) {
   std::shared_ptr<Learner> learner { Learner::Create(mat) };

   BoosterHandle handle = learner.get();
-  learner->UpdateOneIter(0, p_dmat.get());
+  learner->UpdateOneIter(0, p_dmat);

   char const* out[1];
   bst_ulong len {0};
@@ -127,7 +129,7 @@ TEST(c_api, JsonModelIO) {
   std::shared_ptr<Learner> learner { Learner::Create(mat) };

-  learner->UpdateOneIter(0, p_dmat.get());
+  learner->UpdateOneIter(0, p_dmat);

   BoosterHandle handle = learner.get();
   std::string modelfile_0 = tempdir.path + "/model_0.json";
diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc
index 6c8a6a3dc..0b92b4878 100644
--- a/tests/cpp/gbm/test_gbtree.cc
+++ b/tests/cpp/gbm/test_gbtree.cc
@@ -1,3 +1,6 @@
+/*!
+ * Copyright 2019-2020 XGBoost contributors
+ */
 #include
 #include
 #include
@@ -62,7 +65,7 @@ TEST(GBTree, ChoosePredictor) {
   auto learner = std::unique_ptr<Learner>(Learner::Create({p_dmat}));
   learner->SetParams(Args{{"tree_method", "gpu_hist"}, {"gpu_id", "0"}});
   for (size_t i = 0; i < 4; ++i) {
-    learner->UpdateOneIter(i, p_dmat.get());
+    learner->UpdateOneIter(i, p_dmat);
   }
   ASSERT_TRUE(data.HostCanWrite());
   dmlc::TemporaryDirectory tempdir;
@@ -81,7 +84,7 @@ TEST(GBTree, ChoosePredictor) {
   }
   learner->SetParams(Args{{"tree_method", "gpu_hist"}, {"gpu_id", "0"}});
   for (size_t i = 0; i < 4; ++i) {
-    learner->UpdateOneIter(i, p_dmat.get());
+    learner->UpdateOneIter(i, p_dmat);
   }
   ASSERT_TRUE(data.HostCanWrite());

@@ -94,7 +97,7 @@ TEST(GBTree, ChoosePredictor) {
   learner = std::unique_ptr<Learner>(Learner::Create({p_dmat}));
   learner->SetParams(Args{{"tree_method", "gpu_hist"}, {"gpu_id", "0"}});
   for (size_t i = 0; i < 4; ++i) {
-    learner->UpdateOneIter(i, p_dmat.get());
+    learner->UpdateOneIter(i, p_dmat);
   }
   // data is not pulled back into host
   ASSERT_FALSE(data.HostCanWrite());
@@ -196,13 +199,13 @@ TEST(Dart, Prediction) {
   learner->Configure();

   for (size_t i = 0; i < 16; ++i) {
-    learner->UpdateOneIter(i, p_mat.get());
+    learner->UpdateOneIter(i, p_mat);
   }

   HostDeviceVector<float> predts_training;
-  learner->Predict(p_mat.get(), false, &predts_training, 0, true);
+  learner->Predict(p_mat, false, &predts_training, 0, true);
   HostDeviceVector<float> predts_inference;
-  learner->Predict(p_mat.get(), false, &predts_inference, 0, false);
+  learner->Predict(p_mat, false, &predts_inference, 0, false);

   auto& h_predts_training = predts_training.ConstHostVector();
   auto& h_predts_inference = predts_inference.ConstHostVector();
diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc
index c04bca61e..d21c634a9 100644
--- a/tests/cpp/test_learner.cc
+++ b/tests/cpp/test_learner.cc
@@ -1,4 +1,6 @@
-// Copyright by Contributors
+/*!
+ * Copyright 2017-2020 XGBoost contributors
+ */
 #include
 #include
 #include "helpers.h"
@@ -79,13 +81,13 @@ TEST(Learner, CheckGroup) {
   std::vector<std::shared_ptr<DMatrix>> mat = {p_mat};
   auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
   learner->SetParams({Arg{"objective", "rank:pairwise"}});
-  EXPECT_NO_THROW(learner->UpdateOneIter(0, p_mat.get()));
+  EXPECT_NO_THROW(learner->UpdateOneIter(0, p_mat));

   group.resize(kNumGroups+1);
   group[3] = 4;
   group[4] = 1;
   p_mat->Info().SetInfo("group", group.data(), DataType::kUInt32, kNumGroups+1);
-  EXPECT_ANY_THROW(learner->UpdateOneIter(0, p_mat.get()));
+  EXPECT_ANY_THROW(learner->UpdateOneIter(0, p_mat));

   delete pp_mat;
 }
@@ -107,7 +109,7 @@ TEST(Learner, SLOW_CheckMultiBatch) {
   std::vector<std::shared_ptr<DMatrix>> mat{dmat};
   auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
   learner->SetParams(Args{{"objective", "binary:logistic"}});
-  learner->UpdateOneIter(0, dmat.get());
+  learner->UpdateOneIter(0, dmat);
 }

 TEST(Learner, Configuration) {
@@ -142,6 +144,7 @@ TEST(Learner, JsonModelIO) {
   auto pp_dmat = CreateDMatrix(kRows, 10, 0);
   std::shared_ptr<DMatrix> p_dmat {*pp_dmat};
   p_dmat->Info().labels_.Resize(kRows);
+  CHECK_NE(p_dmat->Info().num_col_, 0);

   {
     std::unique_ptr<Learner> learner { Learner::Create({p_dmat}) };
@@ -160,7 +163,7 @@ TEST(Learner, JsonModelIO) {
   {
     std::unique_ptr<Learner> learner { Learner::Create({p_dmat}) };
     for (int32_t iter = 0; iter < kIters; ++iter) {
-      learner->UpdateOneIter(iter, p_dmat.get());
+      learner->UpdateOneIter(iter, p_dmat);
     }

     learner->SetAttr("best_score", "15.2");
@@ -197,20 +200,20 @@ TEST(Learner, GPUConfiguration) {
     std::unique_ptr<Learner> learner {Learner::Create(mat)};
     learner->SetParams({Arg{"booster", "gblinear"},
                         Arg{"updater", "gpu_coord_descent"}});
-    learner->UpdateOneIter(0, p_dmat.get());
+    learner->UpdateOneIter(0, p_dmat);
     ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
   }
   {
     std::unique_ptr<Learner> learner {Learner::Create(mat)};
     learner->SetParams({Arg{"tree_method", "gpu_hist"}});
-    learner->UpdateOneIter(0, p_dmat.get());
+    learner->UpdateOneIter(0, p_dmat);
     ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
   }
   {
     // with CPU algorithm
     std::unique_ptr<Learner> learner {Learner::Create(mat)};
     learner->SetParams({Arg{"tree_method", "hist"}});
-    learner->UpdateOneIter(0, p_dmat.get());
+    learner->UpdateOneIter(0, p_dmat);
     ASSERT_EQ(learner->GetGenericParameter().gpu_id, -1);
   }
   {
@@ -218,7 +221,7 @@ TEST(Learner, GPUConfiguration) {
     std::unique_ptr<Learner> learner {Learner::Create(mat)};
     learner->SetParams({Arg{"tree_method", "hist"},
                         Arg{"gpu_id", "0"}});
-    learner->UpdateOneIter(0, p_dmat.get());
+    learner->UpdateOneIter(0, p_dmat);
     ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
   }
   {
@@ -228,7 +231,7 @@ TEST(Learner, GPUConfiguration) {
     std::unique_ptr<Learner> learner {Learner::Create(mat)};
     learner->SetParams({Arg{"tree_method", "hist"},
                         Arg{"predictor", "gpu_predictor"}});
-    learner->UpdateOneIter(0, p_dmat.get());
+    learner->UpdateOneIter(0, p_dmat);
     ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
   }
diff --git a/tests/cpp/test_serialization.cc b/tests/cpp/test_serialization.cc
index f2bc39267..a6eb836c0 100644
--- a/tests/cpp/test_serialization.cc
+++ b/tests/cpp/test_serialization.cc
@@ -1,3 +1,4 @@
+// Copyright (c) 2019-2020 by Contributors
 #include
 #include
 #include
@@ -24,12 +25,13 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
   std::vector<std::string> dumped_0;
   std::string model_at_kiter;

+  // Train for kIters.
   {
     std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
     std::unique_ptr<Learner> learner {Learner::Create({p_dmat})};
     learner->SetParams(args);
     for (int32_t iter = 0; iter < kIters; ++iter) {
-      learner->UpdateOneIter(iter, p_dmat.get());
+      learner->UpdateOneIter(iter, p_dmat);
     }
     dumped_0 = learner->DumpModel(fmap, true, "json");
     learner->Save(fo.get());
@@ -38,6 +40,7 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
     learner->Save(&mem_out);
   }

+  // Assert dumped model is same after loading
   std::vector<std::string> dumped_1;
   {
     std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r"));
@@ -73,7 +76,7 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
     }

     for (int32_t iter = kIters; iter < 2 * kIters; ++iter) {
-      learner->UpdateOneIter(iter, p_dmat.get());
+      learner->UpdateOneIter(iter, p_dmat);
     }
     common::MemoryBufferStream fo(&continued_model);
     learner->Save(&fo);
@@ -84,7 +87,7 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
     std::unique_ptr<Learner> learner{Learner::Create({p_dmat})};
     learner->SetParams(args);
     for (int32_t iter = 0; iter < 2 * kIters; ++iter) {
-      learner->UpdateOneIter(iter, p_dmat.get());
+      learner->UpdateOneIter(iter, p_dmat);

       // Verify model is same at the same iteration during two training
       // sessions.
@@ -98,6 +101,7 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
     common::MemoryBufferStream fo(&model_at_2kiter);
     learner->Save(&fo);
   }
+
   Json m_0 = Json::Load(StringView{continued_model.c_str(), continued_model.size()});
   Json m_1 = Json::Load(StringView{model_at_2kiter.c_str(), model_at_2kiter.size()});
   ASSERT_EQ(m_0, m_1);
@@ -127,7 +131,7 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
     }

     for (int32_t iter = kIters; iter < 2 * kIters; ++iter) {
-      learner->UpdateOneIter(iter, p_dmat.get());
+      learner->UpdateOneIter(iter, p_dmat);
     }
     serialised_model_tmp = std::string{};
     common::MemoryBufferStream fo(&serialised_model_tmp);
@@ -306,7 +310,7 @@ TEST_F(SerializationTest, ConfigurationCount) {
     learner->SetParam("enable_experimental_json_serialization", "1");
     for (size_t i = 0; i < 10; ++i) {
-      learner->UpdateOneIter(i, p_dmat.get());
+      learner->UpdateOneIter(i, p_dmat);
     }
     common::MemoryBufferStream fo(&model_str);
     learner->Save(&fo);
@@ -317,7 +321,7 @@ TEST_F(SerializationTest, ConfigurationCount) {
     auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
     learner->Load(&fi);
     for (size_t i = 0; i < 10; ++i) {
-      learner->UpdateOneIter(i, p_dmat.get());
+      learner->UpdateOneIter(i, p_dmat);
     }
   }
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 6cb0aad26..d3ed0d8e3 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2017-2019 XGBoost contributors
+ * Copyright 2017-2020 XGBoost contributors
 */
 #include
 #include
@@ -387,6 +387,7 @@ TEST(GpuHist, UniformSampling) {
   constexpr size_t kRows = 4096;
   constexpr size_t kCols = 2;
   constexpr float kSubsample = 0.99;
+  common::GlobalRandom().seed(1994);

   // Create an in-memory DMatrix.
   std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
@@ -415,6 +416,7 @@ TEST(GpuHist, GradientBasedSampling) {
   constexpr size_t kRows = 4096;
   constexpr size_t kCols = 2;
   constexpr float kSubsample = 0.99;
+  common::GlobalRandom().seed(1994);

   // Create an in-memory DMatrix.
   std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
@@ -478,6 +480,7 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
   constexpr size_t kPageSize = 1024;
   constexpr float kSubsample = 0.5;
   const std::string kSamplingMethod = "gradient_based";
+  common::GlobalRandom().seed(0);

   // Create an in-memory DMatrix.
   std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
@@ -503,7 +506,7 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
   auto preds_h = preds.ConstHostVector();
   auto preds_ext_h = preds_ext.ConstHostVector();
   for (int i = 0; i < kRows; i++) {
-    EXPECT_NEAR(preds_h[i], preds_ext_h[i], 3e-3);
+    EXPECT_NEAR(preds_h[i], preds_ext_h[i], 2e-3);
   }
 }
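
After this change, a DMatrixHandle in the C API holds a std::shared_ptr<DMatrix>, and the Learner methods accept that shared pointer directly instead of a raw DMatrix* obtained through .get(). A minimal sketch of the resulting calling convention, assuming the DMatrix::Load and Learner::Create factories as declared in this revision of the headers; the file name, objective, and round count below are illustrative only, not part of this patch:

    #include <xgboost/data.h>
    #include <xgboost/learner.h>

    #include <memory>
    #include <vector>

    using namespace xgboost;

    int main() {
      // DMatrix::Load returns a raw pointer; wrap it in a shared_ptr once and
      // pass that same pointer to both the cache list and each training call.
      std::shared_ptr<DMatrix> dtrain{DMatrix::Load("train.libsvm", true, false)};
      std::unique_ptr<Learner> learner{Learner::Create({dtrain})};
      learner->SetParams({{"objective", "reg:squarederror"}});
      for (int iter = 0; iter < 10; ++iter) {
        learner->UpdateOneIter(iter, dtrain);  // shared_ptr, no .get()
      }
      HostDeviceVector<bst_float> preds;
      learner->Predict(dtrain, false, &preds);
      return 0;
    }

Sharing ownership this way lets the learner's internal prediction cache keep the matrix alive, rather than holding a raw pointer that could dangle once the caller releases its handle.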