More tests for column split and vertical federated learning (#8985)

Added some more tests for the learner and fit_stump, for both column-wise distributed learning and vertical federated learning.

Also moved the `IsRowSplit` and `IsColumnSplit` methods from the `DMatrix` to the `MetaInfo` since in some places we only have access to the `MetaInfo`. Added a new convenience method `IsVerticalFederatedLearning`.

Some refactoring of the testing fixtures.
This commit is contained in:
Rong Ou
2023-03-28 01:40:26 -07:00
committed by GitHub
parent 401ce5cf5e
commit ff26cd3212
18 changed files with 212 additions and 94 deletions

View File

@@ -45,8 +45,7 @@ void FitStump(Context const* ctx, MetaInfo const& info,
}
CHECK(h_sum.CContiguous());
// In vertical federated learning, only worker 0 needs to call this, no need to do an allreduce.
if (!collective::IsFederated() || info.data_split_mode != DataSplitMode::kCol) {
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
}

View File

@@ -449,7 +449,7 @@ class HistEvaluator {
param_{param},
column_sampler_{std::move(sampler)},
tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId},
is_col_split_{info.data_split_mode == DataSplitMode::kCol} {
is_col_split_{info.IsColumnSplit()} {
interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel,

View File

@@ -72,12 +72,13 @@ class GloablApproxBuilder {
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit());
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
p_fmat->Info().IsColumnSplit());
n_batches_++;
}
histogram_builder_.Reset(n_total_bins, BatchSpec(*param_, hess), ctx_->Threads(), n_batches_,
collective::IsDistributed(), p_fmat->IsColumnSplit());
collective::IsDistributed(), p_fmat->Info().IsColumnSplit());
monitor_->Stop(__func__);
}
@@ -91,7 +92,7 @@ class GloablApproxBuilder {
for (auto const &g : gpair) {
root_sum.Add(g);
}
if (p_fmat->IsRowSplit()) {
if (p_fmat->Info().IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
}
std::vector<CPUExpandEntry> nodes{best};

View File

@@ -158,7 +158,7 @@ class MultiTargetHistBuilder {
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit());
partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->Info().IsColumnSplit());
page_id++;
}
@@ -167,7 +167,7 @@ class MultiTargetHistBuilder {
for (std::size_t i = 0; i < n_targets; ++i) {
histogram_builder_.emplace_back();
histogram_builder_.back().Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
collective::IsDistributed(), p_fmat->IsColumnSplit());
collective::IsDistributed(), p_fmat->Info().IsColumnSplit());
}
evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
@@ -388,11 +388,12 @@ class HistBuilder {
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, fmat->IsColumnSplit());
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
fmat->Info().IsColumnSplit());
++page_id;
}
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
collective::IsDistributed(), fmat->IsColumnSplit());
collective::IsDistributed(), fmat->Info().IsColumnSplit());
evaluator_ = std::make_unique<HistEvaluator<CPUExpandEntry>>(ctx_, this->param_, fmat->Info(),
col_sampler_);
p_last_tree_ = p_tree;