More tests for column split and vertical federated learning (#8985)
Added some more tests for the learner and fit_stump, for both column-wise distributed learning and vertical federated learning. Also moved the `IsRowSplit` and `IsColumnSplit` methods from the `DMatrix` to the `MetaInfo` since in some places we only have access to the `MetaInfo`. Added a new convenience method `IsVerticalFederatedLearning`. Some refactoring of the testing fixtures.
This commit is contained in:
@@ -45,8 +45,7 @@ void FitStump(Context const* ctx, MetaInfo const& info,
|
||||
}
|
||||
CHECK(h_sum.CContiguous());
|
||||
|
||||
// In vertical federated learning, only worker 0 needs to call this, no need to do an allreduce.
|
||||
if (!collective::IsFederated() || info.data_split_mode != DataSplitMode::kCol) {
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(
|
||||
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
|
||||
}
|
||||
|
||||
@@ -449,7 +449,7 @@ class HistEvaluator {
|
||||
param_{param},
|
||||
column_sampler_{std::move(sampler)},
|
||||
tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId},
|
||||
is_col_split_{info.data_split_mode == DataSplitMode::kCol} {
|
||||
is_col_split_{info.IsColumnSplit()} {
|
||||
interaction_constraints_.Configure(*param, info.num_col_);
|
||||
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
|
||||
param_->colsample_bynode, param_->colsample_bylevel,
|
||||
|
||||
@@ -72,12 +72,13 @@ class GloablApproxBuilder {
|
||||
} else {
|
||||
CHECK_EQ(n_total_bins, page.cut.TotalBins());
|
||||
}
|
||||
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit());
|
||||
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
|
||||
p_fmat->Info().IsColumnSplit());
|
||||
n_batches_++;
|
||||
}
|
||||
|
||||
histogram_builder_.Reset(n_total_bins, BatchSpec(*param_, hess), ctx_->Threads(), n_batches_,
|
||||
collective::IsDistributed(), p_fmat->IsColumnSplit());
|
||||
collective::IsDistributed(), p_fmat->Info().IsColumnSplit());
|
||||
monitor_->Stop(__func__);
|
||||
}
|
||||
|
||||
@@ -91,7 +92,7 @@ class GloablApproxBuilder {
|
||||
for (auto const &g : gpair) {
|
||||
root_sum.Add(g);
|
||||
}
|
||||
if (p_fmat->IsRowSplit()) {
|
||||
if (p_fmat->Info().IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
|
||||
}
|
||||
std::vector<CPUExpandEntry> nodes{best};
|
||||
|
||||
@@ -158,7 +158,7 @@ class MultiTargetHistBuilder {
|
||||
} else {
|
||||
CHECK_EQ(n_total_bins, page.cut.TotalBins());
|
||||
}
|
||||
partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit());
|
||||
partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->Info().IsColumnSplit());
|
||||
page_id++;
|
||||
}
|
||||
|
||||
@@ -167,7 +167,7 @@ class MultiTargetHistBuilder {
|
||||
for (std::size_t i = 0; i < n_targets; ++i) {
|
||||
histogram_builder_.emplace_back();
|
||||
histogram_builder_.back().Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
|
||||
collective::IsDistributed(), p_fmat->IsColumnSplit());
|
||||
collective::IsDistributed(), p_fmat->Info().IsColumnSplit());
|
||||
}
|
||||
|
||||
evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
|
||||
@@ -388,11 +388,12 @@ class HistBuilder {
|
||||
} else {
|
||||
CHECK_EQ(n_total_bins, page.cut.TotalBins());
|
||||
}
|
||||
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, fmat->IsColumnSplit());
|
||||
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
|
||||
fmat->Info().IsColumnSplit());
|
||||
++page_id;
|
||||
}
|
||||
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
|
||||
collective::IsDistributed(), fmat->IsColumnSplit());
|
||||
collective::IsDistributed(), fmat->Info().IsColumnSplit());
|
||||
evaluator_ = std::make_unique<HistEvaluator<CPUExpandEntry>>(ctx_, this->param_, fmat->Info(),
|
||||
col_sampler_);
|
||||
p_last_tree_ = p_tree;
|
||||
|
||||
Reference in New Issue
Block a user