[coll] Add federated coll. (#9738)
- Define a new data type, the proto file is copied for now. - Merge client and communicator into `FederatedColl`. - Define CUDA variant. - Migrate tests for CPU, add tests for CUDA.
This commit is contained in:
@@ -7,10 +7,10 @@
|
||||
#include <thread> // for thread
|
||||
|
||||
#include "../../../../plugin/federated/federated_comm.h"
|
||||
#include "../../collective/net_test.h" // for SocketTest
|
||||
#include "../../helpers.h" // for ExpectThrow
|
||||
#include "test_worker.h" // for TestFederated
|
||||
#include "xgboost/json.h" // for Json
|
||||
#include "../../collective/test_worker.h" // for SocketTest
|
||||
#include "../../helpers.h" // for ExpectThrow
|
||||
#include "test_worker.h" // for TestFederated
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
@@ -71,14 +71,9 @@ TEST_F(FederatedCommTest, IsDistributed) {
|
||||
|
||||
TEST_F(FederatedCommTest, InsecureTracker) {
|
||||
std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u);
|
||||
TestFederated(n_workers, [=](std::int32_t port, std::int32_t rank) {
|
||||
Json config{Object{}};
|
||||
config["federated_world_size"] = n_workers;
|
||||
config["federated_rank"] = rank;
|
||||
config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedComm comm{config};
|
||||
ASSERT_EQ(comm.Rank(), rank);
|
||||
ASSERT_EQ(comm.World(), n_workers);
|
||||
TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
|
||||
ASSERT_EQ(comm->Rank(), rank);
|
||||
ASSERT_EQ(comm->World(), n_workers);
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
Reference in New Issue
Block a user