More tests for column split and vertical federated learning (#8985)
Added more tests for the learner and fit_stump, covering both column-wise distributed learning and vertical federated learning. Also moved the `IsRowSplit` and `IsColumnSplit` methods from `DMatrix` to `MetaInfo`, since in some places we only have access to the `MetaInfo`, and added a new convenience method `IsVerticalFederated`. Includes some refactoring of the testing fixtures.
parent 401ce5cf5e
commit ff26cd3212
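For context, a rough sketch of how the relocated predicates read at a call site after this change. This is illustrative only and not part of the diff; `DescribeSplit` is a made-up helper, but the three predicates are exactly the ones this commit puts on `MetaInfo`:

// Sketch only (not part of this commit); DescribeSplit is a hypothetical helper.
#include <xgboost/data.h>

void DescribeSplit(xgboost::DMatrix* p_fmat) {
  auto const& info = p_fmat->Info();  // the split predicates now live on MetaInfo
  if (info.IsVerticalFederated()) {
    // Federated + column-wise split: labels live only on worker 0, so gradients
    // and the base score are computed there and broadcast to the other workers.
  } else if (info.IsColumnSplit()) {
    // Column-wise distributed learning without federation.
  } else if (info.IsRowSplit()) {
    // The default row-wise split.
  }
}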
@@ -180,6 +180,22 @@ class MetaInfo {
   */
  void SynchronizeNumberOfColumns();
 
+  /*! \brief Whether the data is split row-wise. */
+  bool IsRowSplit() const {
+    return data_split_mode == DataSplitMode::kRow;
+  }
+
+  /*! \brief Whether the data is split column-wise. */
+  bool IsColumnSplit() const {
+    return data_split_mode == DataSplitMode::kCol;
+  }
+
+  /*!
+   * \brief A convenient method to check if we are doing vertical federated learning, which requires
+   * some special processing.
+   */
+  bool IsVerticalFederated() const;
+
 private:
  void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
  void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
@@ -542,16 +558,6 @@ class DMatrix {
     return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
   }
 
-  /*! \brief Whether the data is split row-wise. */
-  bool IsRowSplit() const {
-    return Info().data_split_mode == DataSplitMode::kRow;
-  }
-
-  /*! \brief Whether the data is split column-wise. */
-  bool IsColumnSplit() const {
-    return Info().data_split_mode == DataSplitMode::kCol;
-  }
-
   /*!
    * \brief Load DMatrix from URI.
    * \param uri The URI of input.
@@ -46,7 +46,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
   if (!use_sorted) {
     HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
                                   HostSketchContainer::UseGroup(info),
-                                  m->IsColumnSplit(), n_threads);
+                                  m->Info().IsColumnSplit(), n_threads);
     for (auto const& page : m->GetBatches<SparsePage>()) {
       container.PushRowPage(page, info, hessian);
     }
@@ -54,7 +54,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
   } else {
     SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
                                     HostSketchContainer::UseGroup(info),
-                                    m->IsColumnSplit(), n_threads};
+                                    m->Info().IsColumnSplit(), n_threads};
     for (auto const& page : m->GetBatches<SortedCSCPage>()) {
       container.PushColPage(page, info, hessian);
     }
@@ -704,7 +704,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
 }
 
 void MetaInfo::SynchronizeNumberOfColumns() {
-  if (collective::IsFederated() && data_split_mode == DataSplitMode::kCol) {
+  if (IsVerticalFederated()) {
     collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
   } else {
     collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
@@ -770,6 +770,10 @@ void MetaInfo::Validate(std::int32_t device) const {
 void MetaInfo::SetInfoFromCUDA(Context const&, StringView, Json) { common::AssertGPUSupport(); }
 #endif  // !defined(XGBOOST_USE_CUDA)
 
+bool MetaInfo::IsVerticalFederated() const {
+  return collective::IsFederated() && IsColumnSplit();
+}
+
 using DMatrixThreadLocal =
     dmlc::ThreadLocalStore<std::map<DMatrix const *, XGBAPIThreadLocalEntry>>;
 
@@ -213,7 +213,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
     SyncFeatureType(&h_ft);
     p_sketch.reset(new common::HostSketchContainer{
         batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(),
-        proxy->IsColumnSplit(), ctx_.Threads()});
+        proxy->Info().IsColumnSplit(), ctx_.Threads()});
   }
   HostAdapterDispatch(proxy, [&](auto const& batch) {
     proxy->Info().num_nonzero_ = batch_nnz[i];
@@ -74,7 +74,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
 }
 
 void SimpleDMatrix::ReindexFeatures() {
-  if (collective::IsFederated() && info_.data_split_mode == DataSplitMode::kCol) {
+  if (info_.IsVerticalFederated()) {
     std::vector<uint64_t> buffer(collective::GetWorldSize());
     buffer[collective::GetRank()] = info_.num_col_;
     collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));
@@ -860,9 +860,9 @@ class LearnerConfiguration : public Learner {
 
   void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
     // Special handling for vertical federated learning.
-    if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
+    if (info.IsVerticalFederated()) {
       // We assume labels are only available on worker 0, so the estimation is calculated there
-      // and added to other workers.
+      // and broadcast to other workers.
       if (collective::GetRank() == 0) {
         UsePtr(obj_)->InitEstimation(info, base_score);
         collective::Broadcast(base_score->Data()->HostPointer(),
@@ -1487,7 +1487,7 @@ class LearnerImpl : public LearnerIO {
   void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
                    HostDeviceVector<GradientPair>* out_gpair) {
     // Special handling for vertical federated learning.
-    if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
+    if (info.IsVerticalFederated()) {
       // We assume labels are only available on worker 0, so the gradients are calculated there
       // and broadcast to other workers.
       if (collective::GetRank() == 0) {
@@ -605,7 +605,7 @@ class CPUPredictor : public Predictor {
  protected:
  void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
                      gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
-    if (p_fmat->IsColumnSplit()) {
+    if (p_fmat->Info().IsColumnSplit()) {
      ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
      helper.PredictDMatrix(p_fmat, out_preds);
      return;
@@ -45,8 +45,7 @@ void FitStump(Context const* ctx, MetaInfo const& info,
   }
   CHECK(h_sum.CContiguous());
 
-  // In vertical federated learning, only worker 0 needs to call this, no need to do an allreduce.
-  if (!collective::IsFederated() || info.data_split_mode != DataSplitMode::kCol) {
+  if (info.IsRowSplit()) {
     collective::Allreduce<collective::Operation::kSum>(
         reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
   }
@@ -449,7 +449,7 @@ class HistEvaluator {
         param_{param},
         column_sampler_{std::move(sampler)},
         tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId},
-        is_col_split_{info.data_split_mode == DataSplitMode::kCol} {
+        is_col_split_{info.IsColumnSplit()} {
     interaction_constraints_.Configure(*param, info.num_col_);
     column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
                           param_->colsample_bynode, param_->colsample_bylevel,
@@ -72,12 +72,13 @@ class GloablApproxBuilder {
     } else {
       CHECK_EQ(n_total_bins, page.cut.TotalBins());
     }
-    partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit());
+    partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
+                              p_fmat->Info().IsColumnSplit());
     n_batches_++;
   }
 
   histogram_builder_.Reset(n_total_bins, BatchSpec(*param_, hess), ctx_->Threads(), n_batches_,
-                           collective::IsDistributed(), p_fmat->IsColumnSplit());
+                           collective::IsDistributed(), p_fmat->Info().IsColumnSplit());
   monitor_->Stop(__func__);
 }
 
@@ -91,7 +92,7 @@ class GloablApproxBuilder {
     for (auto const &g : gpair) {
       root_sum.Add(g);
     }
-    if (p_fmat->IsRowSplit()) {
+    if (p_fmat->Info().IsRowSplit()) {
      collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
    }
    std::vector<CPUExpandEntry> nodes{best};
@@ -158,7 +158,7 @@ class MultiTargetHistBuilder {
     } else {
       CHECK_EQ(n_total_bins, page.cut.TotalBins());
     }
-    partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit());
+    partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->Info().IsColumnSplit());
     page_id++;
   }
 
@@ -167,7 +167,7 @@ class MultiTargetHistBuilder {
     for (std::size_t i = 0; i < n_targets; ++i) {
       histogram_builder_.emplace_back();
       histogram_builder_.back().Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
-                                      collective::IsDistributed(), p_fmat->IsColumnSplit());
+                                      collective::IsDistributed(), p_fmat->Info().IsColumnSplit());
     }
 
     evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
@@ -388,11 +388,12 @@ class HistBuilder {
     } else {
       CHECK_EQ(n_total_bins, page.cut.TotalBins());
     }
-    partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, fmat->IsColumnSplit());
+    partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
+                              fmat->Info().IsColumnSplit());
     ++page_id;
   }
   histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
-                            collective::IsDistributed(), fmat->IsColumnSplit());
+                            collective::IsDistributed(), fmat->Info().IsColumnSplit());
   evaluator_ = std::make_unique<HistEvaluator<CPUExpandEntry>>(ctx_, this->param_, fmat->Info(),
                                                                col_sampler_);
   p_last_tree_ = p_tree;
@@ -191,15 +191,9 @@ double GetMultiMetricEval(xgboost::Metric* metric,
 }
 
 namespace xgboost {
-bool IsNear(std::vector<xgboost::bst_float>::const_iterator _beg1,
-            std::vector<xgboost::bst_float>::const_iterator _end1,
-            std::vector<xgboost::bst_float>::const_iterator _beg2) {
-  for (auto iter1 = _beg1, iter2 = _beg2; iter1 != _end1; ++iter1, ++iter2) {
-    if (std::abs(*iter1 - *iter2) > xgboost::kRtEps){
-      return false;
-    }
-  }
-  return true;
+float GetBaseScore(Json const &config) {
+  return std::stof(get<String const>(config["learner"]["learner_model_param"]["base_score"]));
 }
 
 SimpleLCG::StateType SimpleLCG::operator()() {
@@ -101,9 +101,8 @@ double GetMultiMetricEval(xgboost::Metric* metric,
                           std::vector<xgboost::bst_uint> groups = {});
 
 namespace xgboost {
-bool IsNear(std::vector<xgboost::bst_float>::const_iterator _beg1,
-            std::vector<xgboost::bst_float>::const_iterator _end1,
-            std::vector<xgboost::bst_float>::const_iterator _beg2);
+float GetBaseScore(Json const &config);
 
 /*!
  * \brief Linear congruential generator.
@@ -52,18 +52,33 @@ class BaseFederatedTest : public ::testing::Test {
     server_thread_->join();
   }
 
-  void InitCommunicator(int rank) {
-    Json config{JsonObject()};
-    config["xgboost_communicator"] = String("federated");
-    config["federated_server_address"] = String(server_address_);
-    config["federated_world_size"] = kWorldSize;
-    config["federated_rank"] = rank;
-    xgboost::collective::Init(config);
-  }
-
   static int const kWorldSize{3};
   std::string server_address_;
   std::unique_ptr<std::thread> server_thread_;
   std::unique_ptr<grpc::Server> server_;
 };
 
+template <typename Function, typename... Args>
+void RunWithFederatedCommunicator(int32_t world_size, std::string const& server_address,
+                                  Function&& function, Args&&... args) {
+  std::vector<std::thread> threads;
+  for (auto rank = 0; rank < world_size; rank++) {
+    threads.emplace_back([&, rank]() {
+      Json config{JsonObject()};
+      config["xgboost_communicator"] = String("federated");
+      config["federated_server_address"] = String(server_address);
+      config["federated_world_size"] = world_size;
+      config["federated_rank"] = rank;
+      xgboost::collective::Init(config);
+
+      std::forward<Function>(function)(std::forward<Args>(args)...);
+
+      xgboost::collective::Finalize();
+    });
+  }
+  for (auto& thread : threads) {
+    thread.join();
+  }
+}
+
 }  // namespace xgboost
@@ -1,12 +1,9 @@
 /*!
  * Copyright 2023 XGBoost contributors
  */
-#include <dmlc/parameter.h>
 #include <gtest/gtest.h>
 #include <xgboost/data.h>
 
-#include <fstream>
-#include <iostream>
 #include <thread>
 
 #include "../../../plugin/federated/federated_server.h"
@@ -17,49 +14,40 @@
 
 namespace xgboost {
 
-class FederatedDataTest : public BaseFederatedTest {
- public:
-  void VerifyLoadUri(int rank) {
-    InitCommunicator(rank);
-
-    size_t constexpr kRows{16};
-    size_t const kCols = 8 + rank;
-
-    dmlc::TemporaryDirectory tmpdir;
-    std::string path = tmpdir.path + "/small" + std::to_string(rank) + ".csv";
-    CreateTestCSV(path, kRows, kCols);
-
-    std::unique_ptr<DMatrix> dmat;
-    std::string uri = path + "?format=csv";
-    dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kCol));
-
-    ASSERT_EQ(dmat->Info().num_col_, 8 * kWorldSize + 3);
-    ASSERT_EQ(dmat->Info().num_row_, kRows);
-
-    for (auto const& page : dmat->GetBatches<SparsePage>()) {
-      auto entries = page.GetView().data;
-      auto index = 0;
-      int offsets[] = {0, 8, 17};
-      int offset = offsets[rank];
-      for (auto row = 0; row < kRows; row++) {
-        for (auto col = 0; col < kCols; col++) {
-          EXPECT_EQ(entries[index].index, col + offset);
-          index++;
+class FederatedDataTest : public BaseFederatedTest {};
+
+void VerifyLoadUri() {
+  auto const rank = collective::GetRank();
+
+  size_t constexpr kRows{16};
+  size_t const kCols = 8 + rank;
+
+  dmlc::TemporaryDirectory tmpdir;
+  std::string path = tmpdir.path + "/small" + std::to_string(rank) + ".csv";
+  CreateTestCSV(path, kRows, kCols);
+
+  std::unique_ptr<DMatrix> dmat;
+  std::string uri = path + "?format=csv";
+  dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kCol));
+
+  ASSERT_EQ(dmat->Info().num_col_, 8 * collective::GetWorldSize() + 3);
+  ASSERT_EQ(dmat->Info().num_row_, kRows);
+
+  for (auto const& page : dmat->GetBatches<SparsePage>()) {
+    auto entries = page.GetView().data;
+    auto index = 0;
+    int offsets[] = {0, 8, 17};
+    int offset = offsets[rank];
+    for (auto row = 0; row < kRows; row++) {
+      for (auto col = 0; col < kCols; col++) {
+        EXPECT_EQ(entries[index].index, col + offset);
+        index++;
       }
     }
   }
-
-    xgboost::collective::Finalize();
-  }
-};
-
-TEST_F(FederatedDataTest, LoadUri) {
-  std::vector<std::thread> threads;
-  for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back(&FederatedDataTest_LoadUri_Test::VerifyLoadUri, this, rank);
-  }
-  for (auto& thread : threads) {
-    thread.join();
-  }
 }
+
+TEST_F(FederatedDataTest, LoadUri) {
+  RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyLoadUri);
+}
 }  // namespace xgboost
tests/cpp/plugin/test_federated_learner.cc (new file, 78 lines)
@@ -0,0 +1,78 @@
+/*!
+ * Copyright 2023 XGBoost contributors
+ */
+#include <dmlc/parameter.h>
+#include <gtest/gtest.h>
+#include <xgboost/data.h>
+#include <xgboost/objective.h>
+
+#include "../../../plugin/federated/federated_server.h"
+#include "../../../src/collective/communicator-inl.h"
+#include "../helpers.h"
+#include "helpers.h"
+
+namespace xgboost {
+
+class FederatedLearnerTest : public BaseFederatedTest {
+ protected:
+  static auto constexpr kRows{16};
+  static auto constexpr kCols{16};
+};
+
+void VerifyBaseScore(size_t rows, size_t cols, float expected_base_score) {
+  auto const world_size = collective::GetWorldSize();
+  auto const rank = collective::GetRank();
+  std::shared_ptr<DMatrix> Xy_{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};
+  std::shared_ptr<DMatrix> sliced{Xy_->SliceCol(world_size, rank)};
+  std::unique_ptr<Learner> learner{Learner::Create({sliced})};
+  learner->SetParam("tree_method", "approx");
+  learner->SetParam("objective", "binary:logistic");
+  learner->UpdateOneIter(0, sliced);
+  Json config{Object{}};
+  learner->SaveConfig(&config);
+  auto base_score = GetBaseScore(config);
+  ASSERT_EQ(base_score, expected_base_score);
+}
+
+void VerifyModel(size_t rows, size_t cols, Json const& expected_model) {
+  auto const world_size = collective::GetWorldSize();
+  auto const rank = collective::GetRank();
+  std::shared_ptr<DMatrix> Xy_{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};
+  std::shared_ptr<DMatrix> sliced{Xy_->SliceCol(world_size, rank)};
+  std::unique_ptr<Learner> learner{Learner::Create({sliced})};
+  learner->SetParam("tree_method", "approx");
+  learner->SetParam("objective", "binary:logistic");
+  learner->UpdateOneIter(0, sliced);
+  Json model{Object{}};
+  learner->SaveModel(&model);
+  ASSERT_EQ(model, expected_model);
+}
+
+TEST_F(FederatedLearnerTest, BaseScore) {
+  std::shared_ptr<DMatrix> Xy_{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
+  std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
+  learner->SetParam("tree_method", "approx");
+  learner->SetParam("objective", "binary:logistic");
+  learner->UpdateOneIter(0, Xy_);
+  Json config{Object{}};
+  learner->SaveConfig(&config);
+  auto base_score = GetBaseScore(config);
+  ASSERT_NE(base_score, ObjFunction::DefaultBaseScore());
+
+  RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyBaseScore, kRows, kCols,
+                               base_score);
+}
+
+TEST_F(FederatedLearnerTest, Model) {
+  std::shared_ptr<DMatrix> Xy_{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
+  std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
+  learner->SetParam("tree_method", "approx");
+  learner->SetParam("objective", "binary:logistic");
+  learner->UpdateOneIter(0, Xy_);
+  Json model{Object{}};
+  learner->SaveModel(&model);
+
+  RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyModel, kRows, kCols,
+                               std::cref(model));
+}
+}  // namespace xgboost
@@ -460,10 +460,6 @@ class InitBaseScore : public ::testing::Test {
 
   void SetUp() override { Xy_ = RandomDataGenerator{10, Cols(), 0}.GenerateDMatrix(true); }
 
-  static float GetBaseScore(Json const &config) {
-    return std::stof(get<String const>(config["learner"]["learner_model_param"]["base_score"]));
-  }
-
  public:
  void TestUpdateConfig() {
    std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
@@ -611,4 +607,32 @@ TEST_F(InitBaseScore, InitAfterLoad) { this->TestInitAfterLoad(); }
 TEST_F(InitBaseScore, InitWithPredict) { this->TestInitWithPredt(); }
 
 TEST_F(InitBaseScore, UpdateProcess) { this->TestUpdateProcess(); }
+
+void TestColumnSplitBaseScore(std::shared_ptr<DMatrix> Xy_, float expected_base_score) {
+  auto const world_size = collective::GetWorldSize();
+  auto const rank = collective::GetRank();
+  std::shared_ptr<DMatrix> sliced{Xy_->SliceCol(world_size, rank)};
+  std::unique_ptr<Learner> learner{Learner::Create({sliced})};
+  learner->SetParam("tree_method", "approx");
+  learner->SetParam("objective", "binary:logistic");
+  learner->UpdateOneIter(0, sliced);
+  Json config{Object{}};
+  learner->SaveConfig(&config);
+  auto base_score = GetBaseScore(config);
+  ASSERT_EQ(base_score, expected_base_score);
+}
+
+TEST_F(InitBaseScore, ColumnSplit) {
+  std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
+  learner->SetParam("tree_method", "approx");
+  learner->SetParam("objective", "binary:logistic");
+  learner->UpdateOneIter(0, Xy_);
+  Json config{Object{}};
+  learner->SaveConfig(&config);
+  auto base_score = GetBaseScore(config);
+  ASSERT_NE(base_score, ObjFunction::DefaultBaseScore());
+
+  auto constexpr kWorldSize{3};
+  RunWithInMemoryCommunicator(kWorldSize, &TestColumnSplitBaseScore, Xy_, base_score);
+}
 }  // namespace xgboost
@@ -6,11 +6,12 @@
 
 #include "../../src/common/linalg_op.h"
 #include "../../src/tree/fit_stump.h"
+#include "../helpers.h"
 
 namespace xgboost {
 namespace tree {
 namespace {
-void TestFitStump(Context const *ctx) {
+void TestFitStump(Context const *ctx, DataSplitMode split = DataSplitMode::kRow) {
   std::size_t constexpr kRows = 16, kTargets = 2;
   HostDeviceVector<GradientPair> gpair;
   auto &h_gpair = gpair.HostVector();
@@ -22,6 +23,7 @@ void TestFitStump(Context const *ctx) {
   }
   linalg::Vector<float> out;
   MetaInfo info;
+  info.data_split_mode = split;
   FitStump(ctx, info, gpair, kTargets, &out);
   auto h_out = out.HostView();
   for (auto it = linalg::cbegin(h_out); it != linalg::cend(h_out); ++it) {
@@ -45,5 +47,12 @@ TEST(InitEstimation, GPUFitStump) {
   TestFitStump(&ctx);
 }
 #endif  // defined(XGBOOST_USE_CUDA)
+
+TEST(InitEstimation, FitStumpColumnSplit) {
+  Context ctx;
+  auto constexpr kWorldSize{3};
+  RunWithInMemoryCommunicator(kWorldSize, &TestFitStump, &ctx, DataSplitMode::kCol);
+}
+
 }  // namespace tree
 }  // namespace xgboost
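Usage note: the refactored fixtures let a test body be written as a free function that reads its rank from the communicator, with `RunWithFederatedCommunicator` (or `RunWithInMemoryCommunicator` for the non-federated column-split tests) handling the thread fan-out and communicator setup and teardown. A minimal sketch, assuming a hypothetical `VerifySomething` body:

// Sketch only; VerifySomething is hypothetical. The helper spins up one thread
// per rank, calls collective::Init with the federated config, runs the body,
// then calls collective::Finalize.
void VerifySomething() {
  auto const rank = xgboost::collective::GetRank();
  // per-worker assertions keyed on rank go here
}

TEST_F(BaseFederatedTest, Something) {
  xgboost::RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifySomething);
}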