[coll] Add comm group. (#9759)
- Implement `CommGroup` for double dispatching.
- Small cleanup to tracker for handling abort.
parent c3a0622b49
commit 6c0a190f6d
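The new `CommGroup` pairs a communicator (`Comm`: rabit or federated) with a collective backend (`Coll`: CPU or CUDA) and lazily creates the CUDA variants on first GPU request, so every call site dispatches on both axes through one object. A minimal sketch of that double-dispatch shape, using hypothetical stand-ins rather than the actual XGBoost interfaces:

```cpp
#include <iostream>
#include <memory>
#include <utility>

// Hypothetical stand-ins for the two orthogonal axes the real CommGroup
// dispatches over: which communicator is in use (rabit vs. federated), and
// which collective backend runs the op (CPU vs. CUDA).
struct Comm {
  virtual ~Comm() = default;
  virtual const char* Name() const = 0;
};
struct RabitComm : Comm {
  const char* Name() const override { return "rabit"; }
};
struct FederatedComm : Comm {
  const char* Name() const override { return "federated"; }
};

struct Coll {
  virtual ~Coll() = default;
  virtual Coll* MakeCUDAVar() const;
  virtual const char* Device() const { return "cpu"; }
};
struct CUDAColl : Coll {
  const char* Device() const override { return "cuda"; }
};
Coll* Coll::MakeCUDAVar() const { return new CUDAColl{}; }

// The group owns one object from each hierarchy and lazily caches the GPU
// variant, mirroring the mutable gpu_coll_ member of the real CommGroup.
class Group {
  std::shared_ptr<Comm> comm_;
  std::shared_ptr<Coll> backend_;
  mutable std::shared_ptr<Coll> gpu_coll_;  // created on first GPU request

 public:
  Group(std::shared_ptr<Comm> c, std::shared_ptr<Coll> b)
      : comm_{std::move(c)}, backend_{std::move(b)} {}
  std::shared_ptr<Coll> Backend(bool is_cuda) const {
    if (is_cuda) {
      if (!gpu_coll_) {
        gpu_coll_.reset(backend_->MakeCUDAVar());
      }
      return gpu_coll_;
    }
    return backend_;
  }
  const char* Name() const { return comm_->Name(); }
};

int main() {
  Group g{std::make_shared<RabitComm>(), std::make_shared<Coll>()};
  std::cout << g.Name() << " -> " << g.Backend(false)->Device() << "\n";  // rabit -> cpu
  std::cout << g.Name() << " -> " << g.Backend(true)->Device() << "\n";   // rabit -> cuda
}
```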
@@ -60,7 +60,8 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
   }
 }

-FederatedComm::FederatedComm(Json const& config) {
+FederatedComm::FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id,
+                             Json const& config) {
   /**
    * Topology
    */
@@ -93,6 +94,13 @@ FederatedComm::FederatedComm(Json const& config) {
   CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
   CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required.";

+  /**
+   * Basic config
+   */
+  this->retry_ = retry;
+  this->timeout_ = timeout;
+  this->task_id_ = task_id;
+
   /**
    * Certificates
    */
@@ -11,6 +11,8 @@ namespace xgboost::collective {
 CUDAFederatedComm::CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl)
     : FederatedComm{impl}, stream_{ctx->CUDACtx()->Stream()} {
   CHECK(impl);
+  CHECK(ctx->IsCUDA());
+  dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
 }

 Comm* FederatedComm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll>) const {
@@ -27,6 +27,10 @@ class FederatedComm : public Comm {
     this->rank_ = that->Rank();
     this->world_ = that->World();

+    this->retry_ = that->Retry();
+    this->timeout_ = that->Timeout();
+    this->task_id_ = that->TaskID();
+
     this->tracker_ = that->TrackerInfo();
   }

@@ -41,7 +45,8 @@ class FederatedComm : public Comm {
    * - federated_client_key_path
    * - federated_client_cert_path
    */
-  explicit FederatedComm(Json const& config);
+  explicit FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id,
+                         Json const& config);
   explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world,
                          std::int32_t rank) {
     this->Init(host, port, world, rank, {}, {}, {});
@@ -5,13 +5,17 @@

 #include <algorithm>  // for copy
 #include <chrono>     // for seconds
+#include <cstdlib>    // for exit
 #include <memory>     // for shared_ptr
+#include <mutex>      // for unique_lock
 #include <string>     // for string
 #include <utility>    // for move, forward

 #include "../common/common.h"           // for AssertGPUSupport
+#include "../common/json_utils.h"       // for OptionalArg
 #include "allgather.h"                  // for RingAllgather
 #include "protocol.h"                   // for kMagic
+#include "tracker.h"                    // for GetHostAddress
 #include "xgboost/base.h"               // for XGBOOST_STRICT_R_MODE
 #include "xgboost/collective/socket.h"  // for TCPSocket
 #include "xgboost/json.h"               // for Json, Object
@@ -209,24 +213,18 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
   std::shared_ptr<TCPSocket> error_sock{TCPSocket::CreatePtr(domain)};
   auto eport = error_sock->BindHost();
   error_sock->Listen();
-  error_worker_ = std::thread{[this, error_sock = std::move(error_sock)] {
+  error_worker_ = std::thread{[error_sock = std::move(error_sock)] {
     auto conn = error_sock->Accept();
-    // On Windows accept returns an invalid socket after network is shutdown.
+    // On Windows, accept returns a closed socket after finalize.
     if (conn.IsClosed()) {
       return;
     }
     LOG(WARNING) << "Another worker is running into error.";
-    std::string scmd;
-    conn.Recv(&scmd);
-    auto jcmd = Json::Load(scmd);
-    auto rc = this->Shutdown();
-    if (!rc.OK()) {
-      LOG(WARNING) << "Fail to shutdown worker:" << rc.Report();
-    }
 #if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
-    exit(-1);
+    // exit is nicer than abort as the former performs cleanups.
+    std::exit(-1);
 #else
-    LOG(FATAL) << rc.Report();
+    LOG(FATAL) << "abort";
 #endif
   }};
   error_worker_.detach();
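The new comment in the error-monitor thread explains the switch from `exit` to `std::exit` next to the removed `abort`-style path: `std::exit` runs registered `atexit` handlers and destroys static-storage objects before terminating, while `abort` tears the process down immediately. A tiny standalone illustration of the difference (not part of the commit):

```cpp
#include <cstdlib>
#include <iostream>

int main() {
  // Handlers registered with std::atexit run on std::exit, but not on abort.
  std::atexit([] { std::cout << "cleanup ran\n"; });
  std::exit(-1);   // prints "cleanup ran", then terminates with status -1
  // std::abort(); // would terminate immediately: no handlers, no cleanup
}
```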
src/collective/comm_group.cc (new file, +125)
@@ -0,0 +1,125 @@
+/**
+ * Copyright 2023, XGBoost Contributors
+ */
+#include "comm_group.h"
+
+#include <algorithm>  // for transform
+#include <chrono>     // for seconds
+#include <cstdint>    // for int32_t
+#include <memory>     // for shared_ptr, unique_ptr
+#include <string>     // for string
+#include <vector>     // for vector
+
+#include "../common/json_utils.h"       // for OptionalArg
+#include "coll.h"                       // for Coll
+#include "comm.h"                       // for Comm
+#include "tracker.h"                    // for GetHostAddress
+#include "xgboost/collective/result.h"  // for Result
+#include "xgboost/context.h"            // for DeviceOrd
+#include "xgboost/json.h"               // for Json
+
+#if defined(XGBOOST_USE_FEDERATED)
+#include "../../plugin/federated/federated_coll.h"
+#include "../../plugin/federated/federated_comm.h"
+#endif
+
+namespace xgboost::collective {
+[[nodiscard]] std::shared_ptr<Coll> CommGroup::Backend(DeviceOrd device) const {
+  if (device.IsCUDA()) {
+    if (!gpu_coll_) {
+      gpu_coll_.reset(backend_->MakeCUDAVar());
+    }
+    return gpu_coll_;
+  }
+  return backend_;
+}
+
+[[nodiscard]] Comm const& CommGroup::Ctx(Context const* ctx, DeviceOrd device) const {
+  if (device.IsCUDA()) {
+    CHECK(ctx->IsCUDA());
+    if (!gpu_comm_) {
+      gpu_comm_.reset(comm_->MakeCUDAVar(ctx, backend_));
+    }
+    return *gpu_comm_;
+  }
+  return *comm_;
+}
+
+CommGroup::CommGroup()
+    : comm_{std::shared_ptr<RabitComm>(new RabitComm{})},  // NOLINT
+      backend_{std::shared_ptr<Coll>(new Coll{})} {}       // NOLINT
+
+[[nodiscard]] CommGroup* CommGroup::Create(Json config) {
+  if (IsA<Null>(config)) {
+    return new CommGroup;
+  }
+
+  std::string type = OptionalArg<String>(config, "dmlc_communicator", std::string{"rabit"});
+  std::vector<std::string> keys;
+  // Try both lower and upper case for compatibility
+  auto get_param = [&](std::string name, auto dft, auto t) {
+    std::string upper;
+    std::transform(name.cbegin(), name.cend(), std::back_inserter(upper),
+                   [](char c) { return std::toupper(c); });
+    std::transform(name.cbegin(), name.cend(), name.begin(),
+                   [](char c) { return std::tolower(c); });
+    keys.push_back(upper);
+    keys.push_back(name);
+
+    auto const& obj = get<Object const>(config);
+    auto it = obj.find(upper);
+    if (it != obj.cend()) {
+      return OptionalArg<decltype(t)>(config, upper, dft);
+    } else {
+      return OptionalArg<decltype(t)>(config, name, dft);
+    }
+  };
+  // Common args
+  auto retry =
+      OptionalArg<Integer>(config, "dmlc_retry", static_cast<Integer::Int>(DefaultRetry()));
+  auto timeout = OptionalArg<Integer>(config, "dmlc_timeout_sec",
                                      static_cast<Integer::Int>(DefaultTimeoutSec()));
+  auto task_id = get_param("dmlc_task_id", std::string{}, String{});
+
+  if (type == "rabit") {
+    auto host = get_param("dmlc_tracker_uri", std::string{}, String{});
+    auto port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
+    auto ptr =
+        new CommGroup{std::shared_ptr<RabitComm>{new RabitComm{  // NOLINT
+                          host, static_cast<std::int32_t>(port), std::chrono::seconds{timeout},
+                          static_cast<std::int32_t>(retry), task_id}},
+                      std::shared_ptr<Coll>(new Coll{})};  // NOLINT
+    return ptr;
+  } else if (type == "federated") {
+#if defined(XGBOOST_USE_FEDERATED)
+    auto ptr = new CommGroup{
+        std::make_shared<FederatedComm>(retry, std::chrono::seconds{timeout}, task_id, config),
+        std::make_shared<FederatedColl>()};
+    return ptr;
+#endif  // defined(XGBOOST_USE_FEDERATED)
+  } else {
+    LOG(FATAL) << "Invalid communicator type";
+  }
+
+  return nullptr;
+}
+
+std::unique_ptr<collective::CommGroup>& GlobalCommGroup() {
+  static std::unique_ptr<collective::CommGroup> sptr;
+  if (!sptr) {
+    Json config{Null{}};
+    sptr.reset(CommGroup::Create(config));
+  }
+  return sptr;
+}
+
+void GlobalCommGroupInit(Json config) {
+  auto& sptr = GlobalCommGroup();
+  sptr.reset(CommGroup::Create(std::move(config)));
+}
+
+void GlobalCommGroupFinalize() {
+  auto& sptr = GlobalCommGroup();
+  sptr.reset();
+}
+}  // namespace xgboost::collective
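`CommGroup::Create` accepts both upper- and lower-case parameter spellings (`DMLC_TRACKER_URI` vs. `dmlc_tracker_uri`) through the `get_param` lambda. A standalone sketch of the same compatibility lookup over a plain map (a hypothetical `GetParam` helper, not the XGBoost `OptionalArg`):

```cpp
#include <algorithm>
#include <cctype>
#include <iostream>
#include <iterator>
#include <map>
#include <string>

// Look a key up under both its upper- and lower-case spellings, falling back
// to a default when neither is present. Mirrors the get_param lambda in
// CommGroup::Create, but over a std::map instead of a Json object.
std::string GetParam(std::map<std::string, std::string> const& config,
                     std::string name, std::string const& dft) {
  std::string upper;
  std::transform(name.cbegin(), name.cend(), std::back_inserter(upper),
                 [](unsigned char c) { return std::toupper(c); });
  std::transform(name.cbegin(), name.cend(), name.begin(),
                 [](unsigned char c) { return std::tolower(c); });
  if (auto it = config.find(upper); it != config.cend()) return it->second;
  if (auto it = config.find(name); it != config.cend()) return it->second;
  return dft;
}

int main() {
  std::map<std::string, std::string> config{{"DMLC_TASK_ID", "3"}};
  // The upper-case spelling is checked first; lower-case is the fallback.
  std::cout << GetParam(config, "dmlc_task_id", "0") << "\n";  // prints 3
}
```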
src/collective/comm_group.h (new file, +53)
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2023, XGBoost Contributors
+ */
+#pragma once
+#include <memory>   // for shared_ptr, unique_ptr
+#include <string>   // for string
+#include <utility>  // for move
+
+#include "coll.h"                       // for Coll
+#include "comm.h"                       // for Comm
+#include "xgboost/collective/result.h"  // for Result
+#include "xgboost/collective/socket.h"  // for GetHostName
+
+namespace xgboost::collective {
+/**
+ * @brief Communicator group used for double dispatching between communicators and
+ *        collective implementations.
+ */
+class CommGroup {
+  std::shared_ptr<Comm> comm_;
+  mutable std::shared_ptr<Comm> gpu_comm_;
+
+  std::shared_ptr<Coll> backend_;
+  mutable std::shared_ptr<Coll> gpu_coll_;  // lazy initialization
+
+  CommGroup(std::shared_ptr<Comm> comm, std::shared_ptr<Coll> coll)
+      : comm_{std::move(comm)}, backend_{std::move(coll)} {}
+
+ public:
+  CommGroup();
+
+  [[nodiscard]] auto World() const { return comm_->World(); }
+  [[nodiscard]] auto Rank() const { return comm_->Rank(); }
+  [[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); }
+
+  [[nodiscard]] static CommGroup* Create(Json config);
+
+  [[nodiscard]] std::shared_ptr<Coll> Backend(DeviceOrd device) const;
+  [[nodiscard]] Comm const& Ctx(Context const* ctx, DeviceOrd device) const;
+  [[nodiscard]] Result SignalError(Result const& res) { return comm_->SignalError(res); }
+
+  [[nodiscard]] Result ProcessorName(std::string* out) const {
+    auto rc = GetHostName(out);
+    return rc;
+  }
+};
+
+std::unique_ptr<collective::CommGroup>& GlobalCommGroup();
+
+void GlobalCommGroupInit(Json config);
+
+void GlobalCommGroupFinalize();
+}  // namespace xgboost::collective
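The free functions at the bottom give the rest of the code base one process-wide group: `GlobalCommGroup()` hands out a lazily constructed singleton, and `GlobalCommGroupInit`/`GlobalCommGroupFinalize` rebuild or tear it down. A minimal sketch of that accessor pattern, with a hypothetical `Engine` type standing in for `CommGroup`:

```cpp
#include <iostream>
#include <memory>

struct Engine {
  explicit Engine(int workers) : workers{workers} {}
  int workers;
};

// Function-local static: one instance per process, constructed on first use
// and replaceable through the returned reference -- the same shape as
// GlobalCommGroup()/GlobalCommGroupInit()/GlobalCommGroupFinalize().
std::unique_ptr<Engine>& GlobalEngine() {
  static std::unique_ptr<Engine> sptr;
  if (!sptr) {
    sptr = std::make_unique<Engine>(1);  // default single-worker config
  }
  return sptr;
}

void GlobalEngineInit(int workers) { GlobalEngine() = std::make_unique<Engine>(workers); }
void GlobalEngineFinalize() { GlobalEngine().reset(); }

int main() {
  std::cout << GlobalEngine()->workers << "\n";  // 1: lazily default-built
  GlobalEngineInit(8);
  std::cout << GlobalEngine()->workers << "\n";  // 8: re-initialized
  GlobalEngineFinalize();                        // destroyed explicitly
}
```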
@@ -58,36 +58,35 @@ Result Tracker::WaitUntilReady() const {

 RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr)
     : sock_{std::move(sock)} {
-  auto host = addr.Addr();
-
   std::int32_t rank{0};
-  rc_ = Success()
-        << [&] { return proto::Magic{}.Verify(&sock_); }
-        << [&] { return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_); };
-  if (!rc_.OK()) {
-    return;
-  }
-
-  std::string cmd;
-  sock_.Recv(&cmd);
-  auto jcmd = Json::Load(StringView{cmd});
-  cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
+  Json jcmd;
   std::int32_t port{0};
-  if (cmd_ == proto::CMD::kStart) {
-    proto::Start start;
-    rc_ = start.TrackerHandle(jcmd, &world_, world, &port, &sock_, &eport_);
-  } else if (cmd_ == proto::CMD::kPrint) {
-    proto::Print print;
-    rc_ = print.TrackerHandle(jcmd, &msg_);
-  } else if (cmd_ == proto::CMD::kError) {
-    proto::ErrorCMD error;
-    rc_ = error.TrackerHandle(jcmd, &msg_, &code_);
-  }
-  if (!rc_.OK()) {
-    return;
-  }
-
-  info_ = proto::PeerInfo{host, port, rank};
+
+  rc_ = Success() << [&] { return proto::Magic{}.Verify(&sock_); } << [&] {
+    return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_);
+  } << [&] {
+    std::string cmd;
+    sock_.Recv(&cmd);
+    jcmd = Json::Load(StringView{cmd});
+    cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
+    return Success();
+  } << [&] {
+    if (cmd_ == proto::CMD::kStart) {
+      proto::Start start;
+      return start.TrackerHandle(jcmd, &world_, world, &port, &sock_, &eport_);
+    } else if (cmd_ == proto::CMD::kPrint) {
+      proto::Print print;
+      return print.TrackerHandle(jcmd, &msg_);
+    } else if (cmd_ == proto::CMD::kError) {
+      proto::ErrorCMD error;
+      return error.TrackerHandle(jcmd, &msg_, &code_);
+    }
+    return Success();
+  } << [&] {
+    auto host = addr.Addr();
+    info_ = proto::PeerInfo{host, port, rank};
+    return Success();
+  };
 }

 RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
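The constructor is rewritten as a single chain of `operator<<` over `Result`, so each step runs only if the previous one succeeded and `rc_` ends up holding the first failure. A minimal model of that short-circuiting combinator (a sketch; the real `xgboost::collective::Result` carries richer error context):

```cpp
#include <iostream>
#include <string>
#include <utility>

// A stripped-down success-or-message result type.
struct Result {
  std::string msg;  // empty means success
  bool OK() const { return msg.empty(); }
};
Result Success() { return {}; }
Result Fail(std::string m) { return {std::move(m)}; }

// Run the next step only when everything so far succeeded; otherwise keep
// the first failure. This is the shape of the operator<< chaining used in
// WorkerProxy's constructor.
template <typename Fn>
Result operator<<(Result&& prev, Fn&& fn) {
  if (!prev.OK()) return std::move(prev);
  return fn();
}

int main() {
  auto rc = Success() << [] { return Success(); }
                      << [] { return Fail("handshake failed"); }
                      << [] { return Fail("never runs"); };
  std::cout << (rc.OK() ? std::string{"ok"} : rc.msg) << "\n";  // handshake failed
}
```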
@@ -137,15 +136,18 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {

     std::int32_t n_shutdown{0};
     bool during_restart{false};
+    bool running{false};
     std::vector<WorkerProxy> pending;

     explicit State(std::int32_t world) : n_workers{world} {}
     State(State const& that) = delete;
     State& operator=(State&& that) = delete;

+    // modifiers
     void Start(WorkerProxy&& worker) {
       CHECK_LT(pending.size(), n_workers);
       CHECK_LE(n_shutdown, n_workers);
+      CHECK(!running);

       pending.emplace_back(std::forward<WorkerProxy>(worker));

@@ -155,6 +157,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
       CHECK_GE(n_shutdown, 0);
       CHECK_LT(n_shutdown, n_workers);

+      running = false;
       ++n_shutdown;

       CHECK_LE(n_shutdown, n_workers);
@@ -163,21 +166,26 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
       CHECK_LE(pending.size(), n_workers);
       CHECK_LE(n_shutdown, n_workers);

+      running = false;
       during_restart = true;
     }
-    [[nodiscard]] bool Ready() const {
-      CHECK_LE(pending.size(), n_workers);
-      return static_cast<std::int32_t>(pending.size()) == n_workers;
-    }
     void Bootstrap() {
       CHECK_EQ(pending.size(), n_workers);
       CHECK_LE(n_shutdown, n_workers);

+      running = true;
+
       // A reset.
       n_shutdown = 0;
       during_restart = false;
       pending.clear();
     }

+    // observers
+    [[nodiscard]] bool Ready() const {
+      CHECK_LE(pending.size(), n_workers);
+      return static_cast<std::int32_t>(pending.size()) == n_workers;
+    }
     [[nodiscard]] bool ShouldContinue() const {
       CHECK_LE(pending.size(), n_workers);
       CHECK_LE(n_shutdown, n_workers);
@@ -187,7 +195,31 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
     }
   };

-  return std::async(std::launch::async, [this] {
+  auto handle_error = [&](WorkerProxy const& worker) {
+    auto msg = worker.Msg();
+    auto code = worker.Code();
+    LOG(WARNING) << "Received error from [" << worker.Host() << ":" << worker.Rank() << "]: " << msg
+                 << " code:" << code;
+    auto host = worker.Host();
+    // We signal all workers for the error, if they haven't aborted already.
+    for (auto& w : worker_error_handles_) {
+      if (w.first == host) {
+        continue;
+      }
+      TCPSocket out;
+      // Connecting to the error port as a signal for exit.
+      //
+      // retry is set to 1, just let the worker timeout or error. Otherwise the
+      // tracker and the worker might be waiting for each other.
+      auto rc = Connect(w.first, w.second, 1, timeout_, &out);
+      if (!rc.OK()) {
+        return Fail("Failed to inform workers to stop.");
+      }
+    }
+    return Success();
+  };
+
+  return std::async(std::launch::async, [this, handle_error] {
     State state{this->n_workers_};

     while (state.ShouldContinue()) {
@@ -205,6 +237,16 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
       }
       switch (worker.Command()) {
         case proto::CMD::kStart: {
+          if (state.running) {
+            // Something went wrong with one of the workers. It got disconnected without
+            // notice.
+            state.Error();
+            rc = handle_error(worker);
+            if (!rc.OK()) {
+              return Fail("Failed to handle abort.", std::move(rc));
+            }
+          }
+
           state.Start(std::move(worker));
           if (state.Ready()) {
             rc = this->Bootstrap(&state.pending);
@@ -216,36 +258,20 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
           continue;
         }
         case proto::CMD::kShutdown: {
+          if (state.during_restart) {
+            // The worker can still send shutdown after call to `std::exit`.
+            continue;
+          }
           state.Shutdown();
           continue;
         }
         case proto::CMD::kError: {
           if (state.during_restart) {
+            // Ignore further errors.
             continue;
           }
           state.Error();
-          auto msg = worker.Msg();
-          auto code = worker.Code();
-          LOG(WARNING) << "Recieved error from [" << worker.Host() << ":" << worker.Rank()
-                       << "]: " << msg << " code:" << code;
-          auto host = worker.Host();
-          // We signal all workers for the error, if they haven't aborted already.
-          for (auto& w : worker_error_handles_) {
-            if (w.first == host) {
-              continue;
-            }
-            TCPSocket out;
-            // retry is set to 1, just let the worker timeout or error. Otherwise the
-            // tracker and the worker might be waiting for each other.
-            auto rc = Connect(w.first, w.second, 1, timeout_, &out);
-            // send signal to stop the worker.
-            proto::ShutdownCMD shutdown;
-            rc = shutdown.Send(&out);
-            if (!rc.OK()) {
-              return Fail("Failed to inform workers to stop.");
-            }
-          }
-
+          rc = handle_error(worker);
           continue;
         }
         case proto::CMD::kPrint: {
tests/cpp/collective/test_comm_group.cc (new file, +63)
@@ -0,0 +1,63 @@
+/**
+ * Copyright 2023, XGBoost Contributors
+ */
+#include <gtest/gtest.h>
+#include <xgboost/json.h>  // for Json
+
+#include <chrono>   // for seconds
+#include <cstdint>  // for int32_t
+#include <string>   // for string
+#include <thread>   // for thread
+
+#include "../../../src/collective/comm.h"
+#include "../../../src/collective/comm_group.h"
+#include "../../../src/common/common.h"  // for AllVisibleGPUs
+#include "../helpers.h"                  // for MakeCUDACtx
+#include "test_worker.h"                 // for TestDistributed
+
+namespace xgboost::collective {
+namespace {
+auto MakeConfig(std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) {
+  Json config{Object{}};
+  config["dmlc_communicator"] = std::string{"rabit"};
+  config["DMLC_TRACKER_URI"] = host;
+  config["DMLC_TRACKER_PORT"] = port;
+  config["dmlc_timeout_sec"] = static_cast<std::int64_t>(timeout.count());
+  config["DMLC_TASK_ID"] = std::to_string(r);
+  config["dmlc_retry"] = 2;
+  return config;
+}
+
+class CommGroupTest : public SocketTest {};
+}  // namespace
+
+TEST_F(CommGroupTest, Basic) {
+  std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 5u);
+  TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout,
+                                 std::int32_t r) {
+    Context ctx;
+    auto config = MakeConfig(host, port, timeout, r);
+    std::unique_ptr<CommGroup> ptr{CommGroup::Create(config)};
+    ASSERT_TRUE(ptr->IsDistributed());
+    ASSERT_EQ(ptr->World(), n_workers);
+    auto const& comm = ptr->Ctx(&ctx, DeviceOrd::CPU());
+    ASSERT_EQ(comm.TaskID(), std::to_string(r));
+    ASSERT_EQ(comm.Retry(), 2);
+  });
+}
+
+#if defined(XGBOOST_USE_NCCL)
+TEST_F(CommGroupTest, BasicGPU) {
+  std::int32_t n_workers = common::AllVisibleGPUs();
+  TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout,
+                                 std::int32_t r) {
+    auto ctx = MakeCUDACtx(r);
+    auto config = MakeConfig(host, port, timeout, r);
+    std::unique_ptr<CommGroup> ptr{CommGroup::Create(config)};
+    auto const& comm = ptr->Ctx(&ctx, DeviceOrd::CUDA(0));
+    ASSERT_EQ(comm.TaskID(), std::to_string(r));
+    ASSERT_EQ(comm.Retry(), 2);
+  });
+}
+#endif  // defined(XGBOOST_USE_NCCL)
+}  // namespace xgboost::collective
@@ -95,7 +95,8 @@ void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) {
   std::chrono::seconds timeout{1};

   std::string host;
-  ASSERT_TRUE(GetHostAddress(&host).OK());
+  auto rc = GetHostAddress(&host);
+  ASSERT_TRUE(rc.OK()) << rc.Report();
   RabitTracker tracker{StringView{host}, n_workers, 0, timeout};
   auto fut = tracker.Run();
@@ -15,6 +15,15 @@
 namespace xgboost::linalg {
 namespace {
 DeviceOrd CPU() { return DeviceOrd::CPU(); }
+
+template <typename T>
+void ConstView(linalg::VectorView<T> v1, linalg::VectorView<std::add_const_t<T>> v2) {
+  // compile test for being able to pass non-const view to const view.
+  auto s = v1.Slice(linalg::All());
+  ASSERT_EQ(s.Size(), v1.Size());
+  auto s2 = v2.Slice(linalg::All());
+  ASSERT_EQ(s2.Size(), v2.Size());
+}
 }  // namespace

 auto MakeMatrixFromTest(HostDeviceVector<float> *storage, std::size_t n_rows, std::size_t n_cols) {
@@ -206,6 +215,11 @@ TEST(Linalg, TensorView) {
     ASSERT_TRUE(t.FContiguous());
     ASSERT_FALSE(t.CContiguous());
   }
+  {
+    // const
+    TensorView<double, 1> t{data, {data.size()}, CPU()};
+    ConstView(t, t);
+  }
 }

 TEST(Linalg, Tensor) {
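The new `ConstView` helper is a compile-time check that a mutable `VectorView<T>` converts to `VectorView<const T>`. The same shape exists in the standard library with `std::span`; a short standalone sketch of the idea (analogous, not the linalg types):

```cpp
#include <iostream>
#include <span>
#include <vector>

// Accepting a read-only view lets callers pass either mutable or const data,
// just as the ConstView test passes a mutable VectorView where a
// VectorView<const T> is expected.
double Sum(std::span<double const> v) {
  double s = 0;
  for (auto x : v) s += x;
  return s;
}

int main() {
  std::vector<double> data{1.0, 2.0, 3.0};
  std::span<double> mutable_view{data};
  // Implicit conversion span<double> -> span<const double>.
  std::cout << Sum(mutable_view) << "\n";  // prints 6
}
```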
@@ -124,6 +124,9 @@ TEST_F(FederatedCollTestGPU, Allgather) {

 TEST_F(FederatedCollTestGPU, AllgatherV) {
   std::int32_t n_workers = 2;
+  if (common::AllVisibleGPUs() < n_workers) {
+    GTEST_SKIP_("At least 2 GPUs are required for the test.");
+  }
   TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
     TestAllgatherV(comm, rank);
   });
@@ -1,6 +1,7 @@
 /**
  * Copyright 2022-2023, XGBoost contributors
  */
+#include <gmock/gmock.h>
 #include <gtest/gtest.h>

 #include <string>  // for string
@@ -19,12 +20,14 @@ class FederatedCommTest : public SocketTest {};

 TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
   auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; };
-  ExpectThrow<dmlc::Error>("Invalid world size.", construct);
+  ASSERT_THAT(construct,
+              ::testing::ThrowsMessage<dmlc::Error>(::testing::HasSubstr("Invalid world size")));
 }

 TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
   auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; };
-  ExpectThrow<dmlc::Error>("Invalid worker rank.", construct);
+  ASSERT_THAT(construct,
+              ::testing::ThrowsMessage<dmlc::Error>(::testing::HasSubstr("Invalid worker rank.")));
 }

 TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
@@ -38,7 +41,7 @@ TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) {
     config["federated_server_address"] = std::string("localhost:0");
     config["federated_world_size"] = std::string("1");
     config["federated_rank"] = Integer(0);
-    FederatedComm comm(config);
+    FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
   };
   ExpectThrow<dmlc::Error>("got: `String`", construct);
 }
@@ -49,7 +52,7 @@ TEST_F(FederatedCommTest, ThrowOnRankNotInteger) {
     config["federated_server_address"] = std::string("localhost:0");
     config["federated_world_size"] = 1;
     config["federated_rank"] = std::string("0");
-    FederatedComm comm(config);
+    FederatedComm comm(DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config);
   };
   ExpectThrow<dmlc::Error>("got: `String`", construct);
 }
@@ -59,7 +62,7 @@ TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
   config["federated_world_size"] = 6;
   config["federated_rank"] = 3;
   config["federated_server_address"] = String{"localhost:0"};
-  FederatedComm comm{config};
+  FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
   EXPECT_EQ(comm.World(), 6);
   EXPECT_EQ(comm.Rank(), 3);
 }
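These tests move from the local `ExpectThrow` helper to GoogleMock's `ThrowsMessage` matcher, which asserts both the exception type and a substring of its `what()` message in one expression. Usage in isolation looks like this (a small sketch using the standard gmock matchers):

```cpp
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <stdexcept>

TEST(Matchers, ThrowsMessage) {
  auto construct = [] { throw std::invalid_argument{"Invalid world size."}; };
  // Checks that the callable throws std::invalid_argument whose what()
  // contains the given substring -- the same pattern the federated comm
  // tests now use with dmlc::Error.
  ASSERT_THAT(construct, ::testing::ThrowsMessage<std::invalid_argument>(
                             ::testing::HasSubstr("Invalid world size")));
}
```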
tests/cpp/plugin/federated/test_federated_comm_group.cc (new file, +22)
@@ -0,0 +1,22 @@
+/**
+ * Copyright 2023, XGBoost Contributors
+ */
+#include <gtest/gtest.h>
+#include <xgboost/json.h>  // for Json
+
+#include "../../../../src/collective/comm_group.h"
+#include "../../helpers.h"
+#include "test_worker.h"
+
+namespace xgboost::collective {
+TEST(CommGroup, Federated) {
+  std::int32_t n_workers = common::AllVisibleGPUs();
+  TestFederatedGroup(n_workers, [&](std::shared_ptr<CommGroup> comm_group, std::int32_t r) {
+    Context ctx;
+    ASSERT_EQ(comm_group->Rank(), r);
+    auto const& comm = comm_group->Ctx(&ctx, DeviceOrd::CPU());
+    ASSERT_EQ(comm.TaskID(), std::to_string(r));
+    ASSERT_EQ(comm.Retry(), 2);
+  });
+}
+}  // namespace xgboost::collective

tests/cpp/plugin/federated/test_federated_comm_group.cu (new file, +22)
@@ -0,0 +1,22 @@
+/**
+ * Copyright 2023, XGBoost Contributors
+ */
+#include <gtest/gtest.h>
+#include <xgboost/json.h>  // for Json
+
+#include "../../../../src/collective/comm_group.h"
+#include "../../helpers.h"
+#include "test_worker.h"
+
+namespace xgboost::collective {
+TEST(CommGroup, FederatedGPU) {
+  std::int32_t n_workers = common::AllVisibleGPUs();
+  TestFederatedGroup(n_workers, [&](std::shared_ptr<CommGroup> comm_group, std::int32_t r) {
+    Context ctx = MakeCUDACtx(0);
+    auto const& comm = comm_group->Ctx(&ctx, DeviceOrd::CUDA(0));
+    ASSERT_EQ(comm_group->Rank(), r);
+    ASSERT_EQ(comm.TaskID(), std::to_string(r));
+    ASSERT_EQ(comm.Retry(), 2);
+  });
+}
+}  // namespace xgboost::collective
@@ -5,10 +5,12 @@

 #include <gtest/gtest.h>

-#include <chrono>  // for ms
+#include <chrono>  // for ms, seconds
+#include <memory>  // for shared_ptr
 #include <thread>  // for thread

 #include "../../../../plugin/federated/federated_tracker.h"
+#include "../../../../src/collective/comm_group.h"
 #include "federated_comm.h"  // for FederatedComm
 #include "xgboost/json.h"    // for Json
@@ -23,9 +25,8 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {

   std::vector<std::thread> workers;
   using namespace std::chrono_literals;
-  while (tracker.Port() == 0) {
-    std::this_thread::sleep_for(100ms);
-  }
+  auto rc = tracker.WaitUntilReady();
+  ASSERT_TRUE(rc.OK()) << rc.Report();
   std::int32_t port = tracker.Port();

   for (std::int32_t i = 0; i < n_workers; ++i) {
@@ -34,7 +35,8 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
       config["federated_world_size"] = n_workers;
       config["federated_rank"] = i;
       config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
-      auto comm = std::make_shared<FederatedComm>(config);
+      auto comm = std::make_shared<FederatedComm>(
+          DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, std::to_string(i), config);
+
       fn(comm, i);
     });
@@ -44,7 +46,43 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
     t.join();
   }

-  auto rc = tracker.Shutdown();
+  rc = tracker.Shutdown();
+  ASSERT_TRUE(rc.OK()) << rc.Report();
+  ASSERT_TRUE(fut.get().OK());
+}
+
+template <typename WorkerFn>
+void TestFederatedGroup(std::int32_t n_workers, WorkerFn&& fn) {
+  Json config{Object()};
+  config["federated_secure"] = Boolean{false};
+  config["n_workers"] = Integer{n_workers};
+  FederatedTracker tracker{config};
+  auto fut = tracker.Run();
+
+  std::vector<std::thread> workers;
+  auto rc = tracker.WaitUntilReady();
+  ASSERT_TRUE(rc.OK()) << rc.Report();
+  std::int32_t port = tracker.Port();
+
+  for (std::int32_t i = 0; i < n_workers; ++i) {
+    workers.emplace_back([=] {
+      Json config{Object{}};
+      config["dmlc_communicator"] = std::string{"federated"};
+      config["dmlc_task_id"] = std::to_string(i);
+      config["dmlc_retry"] = 2;
+      config["federated_world_size"] = n_workers;
+      config["federated_rank"] = i;
+      config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
+      std::shared_ptr<CommGroup> comm_group{CommGroup::Create(config)};
+      fn(comm_group, i);
+    });
+  }
+
+  for (auto& t : workers) {
+    t.join();
+  }
+
+  rc = tracker.Shutdown();
   ASSERT_TRUE(rc.OK()) << rc.Report();
   ASSERT_TRUE(fut.get().OK());
 }