More tests for column split and vertical federated learning (#8985)
Added some more tests for the learner and fit_stump, for both column-wise distributed learning and vertical federated learning. Also moved the `IsRowSplit` and `IsColumnSplit` methods from the `DMatrix` to the `MetaInfo` since in some places we only have access to the `MetaInfo`. Added a new convenience method `IsVerticalFederatedLearning`. Some refactoring of the testing fixtures.
This commit is contained in:
parent
401ce5cf5e
commit
ff26cd3212
@ -180,6 +180,22 @@ class MetaInfo {
|
||||
*/
|
||||
void SynchronizeNumberOfColumns();
|
||||
|
||||
/*! \brief Whether the data is split row-wise. */
|
||||
bool IsRowSplit() const {
|
||||
return data_split_mode == DataSplitMode::kRow;
|
||||
}
|
||||
|
||||
/*! \brief Whether the data is split column-wise. */
|
||||
bool IsColumnSplit() const {
|
||||
return data_split_mode == DataSplitMode::kCol;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief A convenient method to check if we are doing vertical federated learning, which requires
|
||||
* some special processing.
|
||||
*/
|
||||
bool IsVerticalFederated() const;
|
||||
|
||||
private:
|
||||
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
|
||||
void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
|
||||
@ -542,16 +558,6 @@ class DMatrix {
|
||||
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
|
||||
}
|
||||
|
||||
/*! \brief Whether the data is split row-wise. */
|
||||
bool IsRowSplit() const {
|
||||
return Info().data_split_mode == DataSplitMode::kRow;
|
||||
}
|
||||
|
||||
/*! \brief Whether the data is split column-wise. */
|
||||
bool IsColumnSplit() const {
|
||||
return Info().data_split_mode == DataSplitMode::kCol;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Load DMatrix from URI.
|
||||
* \param uri The URI of input.
|
||||
|
||||
@ -46,7 +46,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
|
||||
if (!use_sorted) {
|
||||
HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
|
||||
HostSketchContainer::UseGroup(info),
|
||||
m->IsColumnSplit(), n_threads);
|
||||
m->Info().IsColumnSplit(), n_threads);
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
container.PushRowPage(page, info, hessian);
|
||||
}
|
||||
@ -54,7 +54,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
|
||||
} else {
|
||||
SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
|
||||
HostSketchContainer::UseGroup(info),
|
||||
m->IsColumnSplit(), n_threads};
|
||||
m->Info().IsColumnSplit(), n_threads};
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||
container.PushColPage(page, info, hessian);
|
||||
}
|
||||
|
||||
@ -704,7 +704,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
|
||||
}
|
||||
|
||||
void MetaInfo::SynchronizeNumberOfColumns() {
|
||||
if (collective::IsFederated() && data_split_mode == DataSplitMode::kCol) {
|
||||
if (IsVerticalFederated()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
|
||||
} else {
|
||||
collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
|
||||
@ -770,6 +770,10 @@ void MetaInfo::Validate(std::int32_t device) const {
|
||||
void MetaInfo::SetInfoFromCUDA(Context const&, StringView, Json) { common::AssertGPUSupport(); }
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
|
||||
bool MetaInfo::IsVerticalFederated() const {
|
||||
return collective::IsFederated() && IsColumnSplit();
|
||||
}
|
||||
|
||||
using DMatrixThreadLocal =
|
||||
dmlc::ThreadLocalStore<std::map<DMatrix const *, XGBAPIThreadLocalEntry>>;
|
||||
|
||||
|
||||
@ -213,7 +213,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
SyncFeatureType(&h_ft);
|
||||
p_sketch.reset(new common::HostSketchContainer{
|
||||
batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(),
|
||||
proxy->IsColumnSplit(), ctx_.Threads()});
|
||||
proxy->Info().IsColumnSplit(), ctx_.Threads()});
|
||||
}
|
||||
HostAdapterDispatch(proxy, [&](auto const& batch) {
|
||||
proxy->Info().num_nonzero_ = batch_nnz[i];
|
||||
|
||||
@ -74,7 +74,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
|
||||
}
|
||||
|
||||
void SimpleDMatrix::ReindexFeatures() {
|
||||
if (collective::IsFederated() && info_.data_split_mode == DataSplitMode::kCol) {
|
||||
if (info_.IsVerticalFederated()) {
|
||||
std::vector<uint64_t> buffer(collective::GetWorldSize());
|
||||
buffer[collective::GetRank()] = info_.num_col_;
|
||||
collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));
|
||||
|
||||
@ -860,9 +860,9 @@ class LearnerConfiguration : public Learner {
|
||||
|
||||
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
|
||||
// Special handling for vertical federated learning.
|
||||
if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
|
||||
if (info.IsVerticalFederated()) {
|
||||
// We assume labels are only available on worker 0, so the estimation is calculated there
|
||||
// and added to other workers.
|
||||
// and broadcast to other workers.
|
||||
if (collective::GetRank() == 0) {
|
||||
UsePtr(obj_)->InitEstimation(info, base_score);
|
||||
collective::Broadcast(base_score->Data()->HostPointer(),
|
||||
@ -1487,7 +1487,7 @@ class LearnerImpl : public LearnerIO {
|
||||
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
|
||||
HostDeviceVector<GradientPair>* out_gpair) {
|
||||
// Special handling for vertical federated learning.
|
||||
if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
|
||||
if (info.IsVerticalFederated()) {
|
||||
// We assume labels are only available on worker 0, so the gradients are calculated there
|
||||
// and broadcast to other workers.
|
||||
if (collective::GetRank() == 0) {
|
||||
|
||||
@ -605,7 +605,7 @@ class CPUPredictor : public Predictor {
|
||||
protected:
|
||||
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
|
||||
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
|
||||
if (p_fmat->IsColumnSplit()) {
|
||||
if (p_fmat->Info().IsColumnSplit()) {
|
||||
ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
|
||||
helper.PredictDMatrix(p_fmat, out_preds);
|
||||
return;
|
||||
|
||||
@ -45,8 +45,7 @@ void FitStump(Context const* ctx, MetaInfo const& info,
|
||||
}
|
||||
CHECK(h_sum.CContiguous());
|
||||
|
||||
// In vertical federated learning, only worker 0 needs to call this, no need to do an allreduce.
|
||||
if (!collective::IsFederated() || info.data_split_mode != DataSplitMode::kCol) {
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(
|
||||
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
|
||||
}
|
||||
|
||||
@ -449,7 +449,7 @@ class HistEvaluator {
|
||||
param_{param},
|
||||
column_sampler_{std::move(sampler)},
|
||||
tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId},
|
||||
is_col_split_{info.data_split_mode == DataSplitMode::kCol} {
|
||||
is_col_split_{info.IsColumnSplit()} {
|
||||
interaction_constraints_.Configure(*param, info.num_col_);
|
||||
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
|
||||
param_->colsample_bynode, param_->colsample_bylevel,
|
||||
|
||||
@ -72,12 +72,13 @@ class GloablApproxBuilder {
|
||||
} else {
|
||||
CHECK_EQ(n_total_bins, page.cut.TotalBins());
|
||||
}
|
||||
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit());
|
||||
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
|
||||
p_fmat->Info().IsColumnSplit());
|
||||
n_batches_++;
|
||||
}
|
||||
|
||||
histogram_builder_.Reset(n_total_bins, BatchSpec(*param_, hess), ctx_->Threads(), n_batches_,
|
||||
collective::IsDistributed(), p_fmat->IsColumnSplit());
|
||||
collective::IsDistributed(), p_fmat->Info().IsColumnSplit());
|
||||
monitor_->Stop(__func__);
|
||||
}
|
||||
|
||||
@ -91,7 +92,7 @@ class GloablApproxBuilder {
|
||||
for (auto const &g : gpair) {
|
||||
root_sum.Add(g);
|
||||
}
|
||||
if (p_fmat->IsRowSplit()) {
|
||||
if (p_fmat->Info().IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
|
||||
}
|
||||
std::vector<CPUExpandEntry> nodes{best};
|
||||
|
||||
@ -158,7 +158,7 @@ class MultiTargetHistBuilder {
|
||||
} else {
|
||||
CHECK_EQ(n_total_bins, page.cut.TotalBins());
|
||||
}
|
||||
partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit());
|
||||
partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->Info().IsColumnSplit());
|
||||
page_id++;
|
||||
}
|
||||
|
||||
@ -167,7 +167,7 @@ class MultiTargetHistBuilder {
|
||||
for (std::size_t i = 0; i < n_targets; ++i) {
|
||||
histogram_builder_.emplace_back();
|
||||
histogram_builder_.back().Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
|
||||
collective::IsDistributed(), p_fmat->IsColumnSplit());
|
||||
collective::IsDistributed(), p_fmat->Info().IsColumnSplit());
|
||||
}
|
||||
|
||||
evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
|
||||
@ -388,11 +388,12 @@ class HistBuilder {
|
||||
} else {
|
||||
CHECK_EQ(n_total_bins, page.cut.TotalBins());
|
||||
}
|
||||
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, fmat->IsColumnSplit());
|
||||
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
|
||||
fmat->Info().IsColumnSplit());
|
||||
++page_id;
|
||||
}
|
||||
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
|
||||
collective::IsDistributed(), fmat->IsColumnSplit());
|
||||
collective::IsDistributed(), fmat->Info().IsColumnSplit());
|
||||
evaluator_ = std::make_unique<HistEvaluator<CPUExpandEntry>>(ctx_, this->param_, fmat->Info(),
|
||||
col_sampler_);
|
||||
p_last_tree_ = p_tree;
|
||||
|
||||
@ -191,15 +191,9 @@ double GetMultiMetricEval(xgboost::Metric* metric,
|
||||
}
|
||||
|
||||
namespace xgboost {
|
||||
bool IsNear(std::vector<xgboost::bst_float>::const_iterator _beg1,
|
||||
std::vector<xgboost::bst_float>::const_iterator _end1,
|
||||
std::vector<xgboost::bst_float>::const_iterator _beg2) {
|
||||
for (auto iter1 = _beg1, iter2 = _beg2; iter1 != _end1; ++iter1, ++iter2) {
|
||||
if (std::abs(*iter1 - *iter2) > xgboost::kRtEps){
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
||||
float GetBaseScore(Json const &config) {
|
||||
return std::stof(get<String const>(config["learner"]["learner_model_param"]["base_score"]));
|
||||
}
|
||||
|
||||
SimpleLCG::StateType SimpleLCG::operator()() {
|
||||
|
||||
@ -101,9 +101,8 @@ double GetMultiMetricEval(xgboost::Metric* metric,
|
||||
std::vector<xgboost::bst_uint> groups = {});
|
||||
|
||||
namespace xgboost {
|
||||
bool IsNear(std::vector<xgboost::bst_float>::const_iterator _beg1,
|
||||
std::vector<xgboost::bst_float>::const_iterator _end1,
|
||||
std::vector<xgboost::bst_float>::const_iterator _beg2);
|
||||
|
||||
float GetBaseScore(Json const &config);
|
||||
|
||||
/*!
|
||||
* \brief Linear congruential generator.
|
||||
|
||||
@ -52,18 +52,33 @@ class BaseFederatedTest : public ::testing::Test {
|
||||
server_thread_->join();
|
||||
}
|
||||
|
||||
void InitCommunicator(int rank) {
|
||||
Json config{JsonObject()};
|
||||
config["xgboost_communicator"] = String("federated");
|
||||
config["federated_server_address"] = String(server_address_);
|
||||
config["federated_world_size"] = kWorldSize;
|
||||
config["federated_rank"] = rank;
|
||||
xgboost::collective::Init(config);
|
||||
}
|
||||
|
||||
static int const kWorldSize{3};
|
||||
std::string server_address_;
|
||||
std::unique_ptr<std::thread> server_thread_;
|
||||
std::unique_ptr<grpc::Server> server_;
|
||||
};
|
||||
|
||||
template <typename Function, typename... Args>
|
||||
void RunWithFederatedCommunicator(int32_t world_size, std::string const& server_address,
|
||||
Function&& function, Args&&... args) {
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < world_size; rank++) {
|
||||
threads.emplace_back([&, rank]() {
|
||||
Json config{JsonObject()};
|
||||
config["xgboost_communicator"] = String("federated");
|
||||
config["federated_server_address"] = String(server_address);
|
||||
config["federated_world_size"] = world_size;
|
||||
config["federated_rank"] = rank;
|
||||
xgboost::collective::Init(config);
|
||||
|
||||
std::forward<Function>(function)(std::forward<Args>(args)...);
|
||||
|
||||
xgboost::collective::Finalize();
|
||||
});
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
/*!
|
||||
* Copyright 2023 XGBoost contributors
|
||||
*/
|
||||
#include <dmlc/parameter.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/data.h>
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
|
||||
#include "../../../plugin/federated/federated_server.h"
|
||||
@ -17,10 +14,10 @@
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
class FederatedDataTest : public BaseFederatedTest {
|
||||
public:
|
||||
void VerifyLoadUri(int rank) {
|
||||
InitCommunicator(rank);
|
||||
class FederatedDataTest : public BaseFederatedTest {};
|
||||
|
||||
void VerifyLoadUri() {
|
||||
auto const rank = collective::GetRank();
|
||||
|
||||
size_t constexpr kRows{16};
|
||||
size_t const kCols = 8 + rank;
|
||||
@ -33,7 +30,7 @@ class FederatedDataTest : public BaseFederatedTest {
|
||||
std::string uri = path + "?format=csv";
|
||||
dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kCol));
|
||||
|
||||
ASSERT_EQ(dmat->Info().num_col_, 8 * kWorldSize + 3);
|
||||
ASSERT_EQ(dmat->Info().num_col_, 8 * collective::GetWorldSize() + 3);
|
||||
ASSERT_EQ(dmat->Info().num_row_, kRows);
|
||||
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
@ -48,18 +45,9 @@ class FederatedDataTest : public BaseFederatedTest {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
xgboost::collective::Finalize();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
TEST_F(FederatedDataTest, LoadUri) {
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||
threads.emplace_back(&FederatedDataTest_LoadUri_Test::VerifyLoadUri, this, rank);
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyLoadUri);
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
78
tests/cpp/plugin/test_federated_learner.cc
Normal file
78
tests/cpp/plugin/test_federated_learner.cc
Normal file
@ -0,0 +1,78 @@
|
||||
/*!
|
||||
* Copyright 2023 XGBoost contributors
|
||||
*/
|
||||
#include <dmlc/parameter.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/objective.h>
|
||||
|
||||
#include "../../../plugin/federated/federated_server.h"
|
||||
#include "../../../src/collective/communicator-inl.h"
|
||||
#include "../helpers.h"
|
||||
#include "helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
class FederatedLearnerTest : public BaseFederatedTest {
|
||||
protected:
|
||||
static auto constexpr kRows{16};
|
||||
static auto constexpr kCols{16};
|
||||
};
|
||||
|
||||
void VerifyBaseScore(size_t rows, size_t cols, float expected_base_score) {
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
std::shared_ptr<DMatrix> Xy_{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};
|
||||
std::shared_ptr<DMatrix> sliced{Xy_->SliceCol(world_size, rank)};
|
||||
std::unique_ptr<Learner> learner{Learner::Create({sliced})};
|
||||
learner->SetParam("tree_method", "approx");
|
||||
learner->SetParam("objective", "binary:logistic");
|
||||
learner->UpdateOneIter(0, sliced);
|
||||
Json config{Object{}};
|
||||
learner->SaveConfig(&config);
|
||||
auto base_score = GetBaseScore(config);
|
||||
ASSERT_EQ(base_score, expected_base_score);
|
||||
}
|
||||
|
||||
void VerifyModel(size_t rows, size_t cols, Json const& expected_model) {
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
std::shared_ptr<DMatrix> Xy_{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};
|
||||
std::shared_ptr<DMatrix> sliced{Xy_->SliceCol(world_size, rank)};
|
||||
std::unique_ptr<Learner> learner{Learner::Create({sliced})};
|
||||
learner->SetParam("tree_method", "approx");
|
||||
learner->SetParam("objective", "binary:logistic");
|
||||
learner->UpdateOneIter(0, sliced);
|
||||
Json model{Object{}};
|
||||
learner->SaveModel(&model);
|
||||
ASSERT_EQ(model, expected_model);
|
||||
}
|
||||
|
||||
TEST_F(FederatedLearnerTest, BaseScore) {
|
||||
std::shared_ptr<DMatrix> Xy_{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
|
||||
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
|
||||
learner->SetParam("tree_method", "approx");
|
||||
learner->SetParam("objective", "binary:logistic");
|
||||
learner->UpdateOneIter(0, Xy_);
|
||||
Json config{Object{}};
|
||||
learner->SaveConfig(&config);
|
||||
auto base_score = GetBaseScore(config);
|
||||
ASSERT_NE(base_score, ObjFunction::DefaultBaseScore());
|
||||
|
||||
RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyBaseScore, kRows, kCols,
|
||||
base_score);
|
||||
}
|
||||
|
||||
TEST_F(FederatedLearnerTest, Model) {
|
||||
std::shared_ptr<DMatrix> Xy_{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
|
||||
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
|
||||
learner->SetParam("tree_method", "approx");
|
||||
learner->SetParam("objective", "binary:logistic");
|
||||
learner->UpdateOneIter(0, Xy_);
|
||||
Json model{Object{}};
|
||||
learner->SaveModel(&model);
|
||||
|
||||
RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyModel, kRows, kCols,
|
||||
std::cref(model));
|
||||
}
|
||||
} // namespace xgboost
|
||||
@ -460,10 +460,6 @@ class InitBaseScore : public ::testing::Test {
|
||||
|
||||
void SetUp() override { Xy_ = RandomDataGenerator{10, Cols(), 0}.GenerateDMatrix(true); }
|
||||
|
||||
static float GetBaseScore(Json const &config) {
|
||||
return std::stof(get<String const>(config["learner"]["learner_model_param"]["base_score"]));
|
||||
}
|
||||
|
||||
public:
|
||||
void TestUpdateConfig() {
|
||||
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
|
||||
@ -611,4 +607,32 @@ TEST_F(InitBaseScore, InitAfterLoad) { this->TestInitAfterLoad(); }
|
||||
TEST_F(InitBaseScore, InitWithPredict) { this->TestInitWithPredt(); }
|
||||
|
||||
TEST_F(InitBaseScore, UpdateProcess) { this->TestUpdateProcess(); }
|
||||
|
||||
void TestColumnSplitBaseScore(std::shared_ptr<DMatrix> Xy_, float expected_base_score) {
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
std::shared_ptr<DMatrix> sliced{Xy_->SliceCol(world_size, rank)};
|
||||
std::unique_ptr<Learner> learner{Learner::Create({sliced})};
|
||||
learner->SetParam("tree_method", "approx");
|
||||
learner->SetParam("objective", "binary:logistic");
|
||||
learner->UpdateOneIter(0, sliced);
|
||||
Json config{Object{}};
|
||||
learner->SaveConfig(&config);
|
||||
auto base_score = GetBaseScore(config);
|
||||
ASSERT_EQ(base_score, expected_base_score);
|
||||
}
|
||||
|
||||
TEST_F(InitBaseScore, ColumnSplit) {
|
||||
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
|
||||
learner->SetParam("tree_method", "approx");
|
||||
learner->SetParam("objective", "binary:logistic");
|
||||
learner->UpdateOneIter(0, Xy_);
|
||||
Json config{Object{}};
|
||||
learner->SaveConfig(&config);
|
||||
auto base_score = GetBaseScore(config);
|
||||
ASSERT_NE(base_score, ObjFunction::DefaultBaseScore());
|
||||
|
||||
auto constexpr kWorldSize{3};
|
||||
RunWithInMemoryCommunicator(kWorldSize, &TestColumnSplitBaseScore, Xy_, base_score);
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -6,11 +6,12 @@
|
||||
|
||||
#include "../../src/common/linalg_op.h"
|
||||
#include "../../src/tree/fit_stump.h"
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
namespace {
|
||||
void TestFitStump(Context const *ctx) {
|
||||
void TestFitStump(Context const *ctx, DataSplitMode split = DataSplitMode::kRow) {
|
||||
std::size_t constexpr kRows = 16, kTargets = 2;
|
||||
HostDeviceVector<GradientPair> gpair;
|
||||
auto &h_gpair = gpair.HostVector();
|
||||
@ -22,6 +23,7 @@ void TestFitStump(Context const *ctx) {
|
||||
}
|
||||
linalg::Vector<float> out;
|
||||
MetaInfo info;
|
||||
info.data_split_mode = split;
|
||||
FitStump(ctx, info, gpair, kTargets, &out);
|
||||
auto h_out = out.HostView();
|
||||
for (auto it = linalg::cbegin(h_out); it != linalg::cend(h_out); ++it) {
|
||||
@ -45,5 +47,12 @@ TEST(InitEstimation, GPUFitStump) {
|
||||
TestFitStump(&ctx);
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
|
||||
TEST(InitEstimation, FitStumpColumnSplit) {
|
||||
Context ctx;
|
||||
auto constexpr kWorldSize{3};
|
||||
RunWithInMemoryCommunicator(kWorldSize, &TestFitStump, &ctx, DataSplitMode::kCol);
|
||||
}
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user