Support vertical federated learning (#8932)

This commit is contained in:
Rong Ou
2023-03-21 23:25:26 -07:00
committed by GitHub
parent 8dc1e4b3ea
commit b240f055d3
23 changed files with 371 additions and 249 deletions

View File

@@ -440,7 +440,7 @@ class LearnerConfiguration : public Learner {
info.Validate(Ctx()->gpu_id);
// We estimate it from input data.
linalg::Tensor<float, 1> base_score;
UsePtr(obj_)->InitEstimation(info, &base_score);
InitEstimation(info, &base_score);
CHECK_EQ(base_score.Size(), 1);
mparam_.base_score = base_score(0);
CHECK(!std::isnan(mparam_.base_score));
@@ -857,6 +857,25 @@ class LearnerConfiguration : public Learner {
mparam_.num_target = n_targets;
}
}
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
// Special handling for vertical federated learning.
if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
// We assume labels are only available on worker 0, so the estimation is calculated there
// and added to other workers.
if (collective::GetRank() == 0) {
UsePtr(obj_)->InitEstimation(info, base_score);
collective::Broadcast(base_score->Data()->HostPointer(),
sizeof(bst_float) * base_score->Size(), 0);
} else {
base_score->Reshape(1);
collective::Broadcast(base_score->Data()->HostPointer(),
sizeof(bst_float) * base_score->Size(), 0);
}
} else {
UsePtr(obj_)->InitEstimation(info, base_score);
}
}
};
std::string const LearnerConfiguration::kEvalMetric {"eval_metric"}; // NOLINT
@@ -1307,7 +1326,7 @@ class LearnerImpl : public LearnerIO {
monitor_.Stop("PredictRaw");
monitor_.Start("GetGradient");
obj_->GetGradient(predt.predictions, train->Info(), iter, &gpair_);
GetGradient(predt.predictions, train->Info(), iter, &gpair_);
monitor_.Stop("GetGradient");
TrainingObserver::Instance().Observe(gpair_, "Gradients");
@@ -1486,6 +1505,28 @@ class LearnerImpl : public LearnerIO {
}
private:
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
HostDeviceVector<GradientPair>* out_gpair) {
// Special handling for vertical federated learning.
if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
// We assume labels are only available on worker 0, so the gradients are calculated there
// and broadcast to other workers.
if (collective::GetRank() == 0) {
obj_->GetGradient(preds, info, iteration, out_gpair);
collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
0);
} else {
CHECK_EQ(info.labels.Size(), 0)
<< "In vertical federated learning, labels should only be on the first worker";
out_gpair->Resize(preds.Size());
collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
0);
}
} else {
obj_->GetGradient(preds, info, iteration, out_gpair);
}
}
/*! \brief random number transformation seed. */
static int32_t constexpr kRandSeedMagic = 127;
// gradient pairs