More tests for column split and vertical federated learning (#8985)

Added some more tests for the learner and fit_stump, for both column-wise distributed learning and vertical federated learning.

Also moved the `IsRowSplit` and `IsColumnSplit` methods from the `DMatrix` to the `MetaInfo` since in some places we only have access to the `MetaInfo`. Added a new convenience method `IsVerticalFederatedLearning`.

Some refactoring of the testing fixtures.
This commit is contained in:
Rong Ou
2023-03-28 01:40:26 -07:00
committed by GitHub
parent 401ce5cf5e
commit ff26cd3212
18 changed files with 212 additions and 94 deletions

View File

@@ -46,7 +46,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
if (!use_sorted) {
HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info),
m->IsColumnSplit(), n_threads);
m->Info().IsColumnSplit(), n_threads);
for (auto const& page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info, hessian);
}
@@ -54,7 +54,7 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
} else {
SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info),
m->IsColumnSplit(), n_threads};
m->Info().IsColumnSplit(), n_threads};
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
container.PushColPage(page, info, hessian);
}

View File

@@ -704,7 +704,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
}
void MetaInfo::SynchronizeNumberOfColumns() {
if (collective::IsFederated() && data_split_mode == DataSplitMode::kCol) {
if (IsVerticalFederated()) {
collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
} else {
collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
@@ -770,6 +770,10 @@ void MetaInfo::Validate(std::int32_t device) const {
void MetaInfo::SetInfoFromCUDA(Context const&, StringView, Json) { common::AssertGPUSupport(); }
#endif // !defined(XGBOOST_USE_CUDA)
bool MetaInfo::IsVerticalFederated() const {
return collective::IsFederated() && IsColumnSplit();
}
using DMatrixThreadLocal =
dmlc::ThreadLocalStore<std::map<DMatrix const *, XGBAPIThreadLocalEntry>>;

View File

@@ -213,7 +213,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
SyncFeatureType(&h_ft);
p_sketch.reset(new common::HostSketchContainer{
batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(),
proxy->IsColumnSplit(), ctx_.Threads()});
proxy->Info().IsColumnSplit(), ctx_.Threads()});
}
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];

View File

@@ -74,7 +74,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
}
void SimpleDMatrix::ReindexFeatures() {
if (collective::IsFederated() && info_.data_split_mode == DataSplitMode::kCol) {
if (info_.IsVerticalFederated()) {
std::vector<uint64_t> buffer(collective::GetWorldSize());
buffer[collective::GetRank()] = info_.num_col_;
collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));

View File

@@ -860,9 +860,9 @@ class LearnerConfiguration : public Learner {
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
// Special handling for vertical federated learning.
if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the estimation is calculated there
// and added to other workers.
// and broadcast to other workers.
if (collective::GetRank() == 0) {
UsePtr(obj_)->InitEstimation(info, base_score);
collective::Broadcast(base_score->Data()->HostPointer(),
@@ -1487,7 +1487,7 @@ class LearnerImpl : public LearnerIO {
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
HostDeviceVector<GradientPair>* out_gpair) {
// Special handling for vertical federated learning.
if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the gradients are calculated there
// and broadcast to other workers.
if (collective::GetRank() == 0) {

View File

@@ -605,7 +605,7 @@ class CPUPredictor : public Predictor {
protected:
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
if (p_fmat->IsColumnSplit()) {
if (p_fmat->Info().IsColumnSplit()) {
ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
helper.PredictDMatrix(p_fmat, out_preds);
return;

View File

@@ -45,8 +45,7 @@ void FitStump(Context const* ctx, MetaInfo const& info,
}
CHECK(h_sum.CContiguous());
// In vertical federated learning, only worker 0 needs to call this, no need to do an allreduce.
if (!collective::IsFederated() || info.data_split_mode != DataSplitMode::kCol) {
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
}

View File

@@ -449,7 +449,7 @@ class HistEvaluator {
param_{param},
column_sampler_{std::move(sampler)},
tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId},
is_col_split_{info.data_split_mode == DataSplitMode::kCol} {
is_col_split_{info.IsColumnSplit()} {
interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel,

View File

@@ -72,12 +72,13 @@ class GloablApproxBuilder {
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit());
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
p_fmat->Info().IsColumnSplit());
n_batches_++;
}
histogram_builder_.Reset(n_total_bins, BatchSpec(*param_, hess), ctx_->Threads(), n_batches_,
collective::IsDistributed(), p_fmat->IsColumnSplit());
collective::IsDistributed(), p_fmat->Info().IsColumnSplit());
monitor_->Stop(__func__);
}
@@ -91,7 +92,7 @@ class GloablApproxBuilder {
for (auto const &g : gpair) {
root_sum.Add(g);
}
if (p_fmat->IsRowSplit()) {
if (p_fmat->Info().IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
}
std::vector<CPUExpandEntry> nodes{best};

View File

@@ -158,7 +158,7 @@ class MultiTargetHistBuilder {
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit());
partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->Info().IsColumnSplit());
page_id++;
}
@@ -167,7 +167,7 @@ class MultiTargetHistBuilder {
for (std::size_t i = 0; i < n_targets; ++i) {
histogram_builder_.emplace_back();
histogram_builder_.back().Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
collective::IsDistributed(), p_fmat->IsColumnSplit());
collective::IsDistributed(), p_fmat->Info().IsColumnSplit());
}
evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
@@ -388,11 +388,12 @@ class HistBuilder {
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, fmat->IsColumnSplit());
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
fmat->Info().IsColumnSplit());
++page_id;
}
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
collective::IsDistributed(), fmat->IsColumnSplit());
collective::IsDistributed(), fmat->Info().IsColumnSplit());
evaluator_ = std::make_unique<HistEvaluator<CPUExpandEntry>>(ctx_, this->param_, fmat->Info(),
col_sampler_);
p_last_tree_ = p_tree;