Make sure metrics work with federated learning (#9037)

This commit is contained in:
Rong Ou
2023-04-19 00:39:11 -07:00
committed by GitHub
parent ef13dd31b1
commit 42d100de18
11 changed files with 451 additions and 152 deletions

View File

@@ -34,6 +34,7 @@
#include <utility> // for pair, as_const, move, swap
#include <vector> // for vector
#include "collective/aggregator.h" // for ApplyWithLabels
#include "collective/communicator-inl.h" // for Allreduce, Broadcast, GetRank, IsDistributed
#include "collective/communicator.h" // for Operation
#include "common/api_entry.h" // for XGBAPIThreadLocalEntry
@@ -859,22 +860,10 @@ class LearnerConfiguration : public Learner {
}
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
// Special handling for vertical federated learning.
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the estimation is calculated there
// and broadcast to other workers.
if (collective::GetRank() == 0) {
UsePtr(obj_)->InitEstimation(info, base_score);
collective::Broadcast(base_score->Data()->HostPointer(),
sizeof(bst_float) * base_score->Size(), 0);
} else {
base_score->Reshape(1);
collective::Broadcast(base_score->Data()->HostPointer(),
sizeof(bst_float) * base_score->Size(), 0);
}
} else {
UsePtr(obj_)->InitEstimation(info, base_score);
}
base_score->Reshape(1);
collective::ApplyWithLabels(info, base_score->Data()->HostPointer(),
sizeof(bst_float) * base_score->Size(),
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
}
};
@@ -1486,24 +1475,10 @@ class LearnerImpl : public LearnerIO {
private:
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
HostDeviceVector<GradientPair>* out_gpair) {
// Special handling for vertical federated learning.
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the gradients are calculated there
// and broadcast to other workers.
if (collective::GetRank() == 0) {
obj_->GetGradient(preds, info, iteration, out_gpair);
collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
0);
} else {
CHECK_EQ(info.labels.Size(), 0)
<< "In vertical federated learning, labels should only be on the first worker";
out_gpair->Resize(preds.Size());
collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
0);
}
} else {
obj_->GetGradient(preds, info, iteration, out_gpair);
}
out_gpair->Resize(preds.Size());
collective::ApplyWithLabels(info, out_gpair->HostPointer(),
out_gpair->Size() * sizeof(GradientPair),
[&] { obj_->GetGradient(preds, info, iteration, out_gpair); });
}
/*! \brief random number transformation seed. */