[coll] Add federated coll. (#9738)
- Define a new data type, the proto file is copied for now. - Merge client and communicator into `FederatedColl`. - Define CUDA variant. - Migrate tests for CPU, add tests for CUDA.
This commit is contained in:
@@ -9,7 +9,8 @@
|
||||
#include <thread> // for thread
|
||||
|
||||
#include "../../../../plugin/federated/federated_tracker.h"
|
||||
#include "xgboost/json.h" // for Json
|
||||
#include "federated_comm.h" // for FederatedComm
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
template <typename WorkerFn>
|
||||
@@ -28,7 +29,15 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] { fn(port, i); });
|
||||
workers.emplace_back([=] {
|
||||
Json config{Object{}};
|
||||
config["federated_world_size"] = n_workers;
|
||||
config["federated_rank"] = i;
|
||||
config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
|
||||
auto comm = std::make_shared<FederatedComm>(config);
|
||||
|
||||
fn(comm, i);
|
||||
});
|
||||
}
|
||||
|
||||
for (auto& t : workers) {
|
||||
|
||||
Reference in New Issue
Block a user