Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new implementation features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
@@ -108,6 +108,32 @@ TEST_F(FederatedCollTestGPU, Allreduce) {
|
||||
});
|
||||
}
|
||||
|
||||
// Runs a bitwise-OR allreduce over device buffers through the global collective
// entry points (GetRank / GetWorldSize / Allreduce), one worker per visible GPU.
TEST(FederatedCollGPUGlobal, Allreduce) {
  std::int32_t n_workers = common::AllVisibleGPUs();
  TestFederatedGlobal(n_workers, [&] {
    auto rank = collective::GetRank();
    auto world_size = collective::GetWorldSize();
    CHECK_EQ(n_workers, world_size);

    // Each worker contributes three elements, all equal to its own rank.
    dh::device_vector<std::uint32_t> d_values(3, rank);
    auto ctx = MakeCUDACtx(rank);
    auto rc = collective::Allreduce(
        &ctx, linalg::MakeVec(d_values.data().get(), d_values.size(), DeviceOrd::CUDA(rank)),
        Op::kBitwiseOR);
    SafeColl(rc);

    // The reference result is the OR of every rank in [0, world_size); it is the
    // same for every element, so fold it once and fill the vector with it.
    std::uint32_t or_of_ranks = 0;
    for (std::int32_t k = 0; k < world_size; ++k) {
      or_of_ranks |= k;
    }
    std::vector<std::uint32_t> expected(d_values.size(), or_of_ranks);

    for (std::size_t idx = 0; idx < expected.size(); ++idx) {
      CHECK_EQ(expected[idx], d_values[idx]);
    }
  });
}
|
||||
|
||||
TEST_F(FederatedCollTestGPU, Broadcast) {
|
||||
std::int32_t n_workers = common::AllVisibleGPUs();
|
||||
TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
|
||||
|
||||
@@ -11,12 +11,24 @@
|
||||
|
||||
#include "../../../../plugin/federated/federated_tracker.h"
|
||||
#include "../../../../src/collective/comm_group.h"
|
||||
#include "../../../../src/collective/communicator-inl.h"
|
||||
#include "federated_comm.h" // for FederatedComm
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
 * @brief Assemble the per-worker JSON configuration used by the federated tests.
 *
 * @param n_workers Stored under "federated_world_size".
 * @param port      Appended to "0.0.0.0:" to form "federated_server_address".
 * @param i         Stored under "federated_rank"; its decimal string becomes "dmlc_task_id".
 *
 * @return A JSON object holding the six entries set below.
 */
inline Json FederatedTestConfig(std::int32_t n_workers, std::int32_t port, std::int32_t i) {
  Json worker_cfg{Object{}};

  // Entries prefixed with "federated_".
  worker_cfg["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
  worker_cfg["federated_world_size"] = n_workers;
  worker_cfg["federated_rank"] = i;

  // Entries prefixed with "dmlc_".
  worker_cfg["dmlc_communicator"] = std::string{"federated"};
  worker_cfg["dmlc_task_id"] = std::to_string(i);
  worker_cfg["dmlc_retry"] = 2;

  return worker_cfg;
}
|
||||
|
||||
template <typename WorkerFn>
|
||||
void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
void TestFederatedImpl(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
Json config{Object()};
|
||||
config["federated_secure"] = Boolean{false};
|
||||
config["n_workers"] = Integer{n_workers};
|
||||
@@ -30,16 +42,7 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] {
|
||||
Json config{Object{}};
|
||||
config["federated_world_size"] = n_workers;
|
||||
config["federated_rank"] = i;
|
||||
config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
|
||||
auto comm = std::make_shared<FederatedComm>(
|
||||
DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, std::to_string(i), config);
|
||||
|
||||
fn(comm, i);
|
||||
});
|
||||
workers.emplace_back([=] { fn(port, i); });
|
||||
}
|
||||
|
||||
for (auto& t : workers) {
|
||||
@@ -51,39 +54,33 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
|
||||
template <typename WorkerFn>
|
||||
void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
TestFederatedImpl(n_workers, [&](std::int32_t port, std::int32_t i) {
|
||||
auto config = FederatedTestConfig(n_workers, port, i);
|
||||
auto comm = std::make_shared<FederatedComm>(
|
||||
DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, std::to_string(i), config);
|
||||
|
||||
fn(comm, i);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename WorkerFn>
|
||||
void TestFederatedGroup(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
Json config{Object()};
|
||||
config["federated_secure"] = Boolean{false};
|
||||
config["n_workers"] = Integer{n_workers};
|
||||
FederatedTracker tracker{config};
|
||||
auto fut = tracker.Run();
|
||||
TestFederatedImpl(n_workers, [&](std::int32_t port, std::int32_t i) {
|
||||
auto config = FederatedTestConfig(n_workers, port, i);
|
||||
std::shared_ptr<CommGroup> comm_group{CommGroup::Create(config)};
|
||||
fn(comm_group, i);
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
auto rc = tracker.WaitUntilReady();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] {
|
||||
Json config{Object{}};
|
||||
config["dmlc_communicator"] = std::string{"federated"};
|
||||
config["dmlc_task_id"] = std::to_string(i);
|
||||
config["dmlc_retry"] = 2;
|
||||
config["federated_world_size"] = n_workers;
|
||||
config["federated_rank"] = i;
|
||||
config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
|
||||
std::shared_ptr<CommGroup> comm_group{CommGroup::Create(config)};
|
||||
fn(comm_group, i);
|
||||
});
|
||||
}
|
||||
|
||||
for (auto& t : workers) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
rc = tracker.Shutdown();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
template <typename WorkerFn>
|
||||
void TestFederatedGlobal(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
TestFederatedImpl(n_workers, [&](std::int32_t port, std::int32_t i) {
|
||||
auto config = FederatedTestConfig(n_workers, port, i);
|
||||
collective::Init(config);
|
||||
fn();
|
||||
collective::Finalize();
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
Reference in New Issue
Block a user