[coll] Add comm group. (#9759)

- Implement `CommGroup` for double dispatching.
- Small cleanup to tracker for handling abort.
This commit is contained in:
Jiaming Yuan
2023-11-07 11:12:31 +08:00
committed by GitHub
parent c3a0622b49
commit 6c0a190f6d
15 changed files with 462 additions and 79 deletions

View File

@@ -5,13 +5,17 @@
#include <algorithm> // for copy
#include <chrono> // for seconds
#include <cstdlib> // for exit
#include <memory> // for shared_ptr
#include <mutex> // for unique_lock
#include <string> // for string
#include <utility> // for move, forward
#include "../common/common.h" // for AssertGPUSupport
#include "../common/json_utils.h" // for OptionalArg
#include "allgather.h" // for RingAllgather
#include "protocol.h" // for kMagic
#include "tracker.h" // for GetHostAddress
#include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE
#include "xgboost/collective/socket.h" // for TCPSocket
#include "xgboost/json.h" // for Json, Object
@@ -209,24 +213,18 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
std::shared_ptr<TCPSocket> error_sock{TCPSocket::CreatePtr(domain)};
auto eport = error_sock->BindHost();
error_sock->Listen();
error_worker_ = std::thread{[this, error_sock = std::move(error_sock)] {
error_worker_ = std::thread{[error_sock = std::move(error_sock)] {
auto conn = error_sock->Accept();
// On Windows accept returns an invalid socket after network is shutdown.
// On Windows, accept returns a closed socket after finalize.
if (conn.IsClosed()) {
return;
}
LOG(WARNING) << "Another worker is running into error.";
std::string scmd;
conn.Recv(&scmd);
auto jcmd = Json::Load(scmd);
auto rc = this->Shutdown();
if (!rc.OK()) {
LOG(WARNING) << "Fail to shutdown worker:" << rc.Report();
}
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
exit(-1);
// exit is nicer than abort as the former performs cleanups.
std::exit(-1);
#else
LOG(FATAL) << rc.Report();
LOG(FATAL) << "abort";
#endif
}};
error_worker_.detach();

View File

@@ -0,0 +1,125 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "comm_group.h"
#include <algorithm> // for transform
#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <memory> // for shared_ptr, unique_ptr
#include <string> // for string
#include <vector> // for vector
#include "../common/json_utils.h" // for OptionalArg
#include "coll.h" // for Coll
#include "comm.h" // for Comm
#include "tracker.h" // for GetHostAddress
#include "xgboost/collective/result.h" // for Result
#include "xgboost/context.h" // for DeviceOrd
#include "xgboost/json.h" // for Json
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_coll.h"
#include "../../plugin/federated/federated_comm.h"
#endif
namespace xgboost::collective {
[[nodiscard]] std::shared_ptr<Coll> CommGroup::Backend(DeviceOrd device) const {
if (device.IsCUDA()) {
if (!gpu_coll_) {
gpu_coll_.reset(backend_->MakeCUDAVar());
}
return gpu_coll_;
}
return backend_;
}
[[nodiscard]] Comm const& CommGroup::Ctx(Context const* ctx, DeviceOrd device) const {
if (device.IsCUDA()) {
CHECK(ctx->IsCUDA());
if (!gpu_comm_) {
gpu_comm_.reset(comm_->MakeCUDAVar(ctx, backend_));
}
return *gpu_comm_;
}
return *comm_;
}
CommGroup::CommGroup()
: comm_{std::shared_ptr<RabitComm>(new RabitComm{})}, // NOLINT
backend_{std::shared_ptr<Coll>(new Coll{})} {} // NOLINT
[[nodiscard]] CommGroup* CommGroup::Create(Json config) {
if (IsA<Null>(config)) {
return new CommGroup;
}
std::string type = OptionalArg<String>(config, "dmlc_communicator", std::string{"rabit"});
std::vector<std::string> keys;
// Try both lower and upper case for compatibility
auto get_param = [&](std::string name, auto dft, auto t) {
std::string upper;
std::transform(name.cbegin(), name.cend(), std::back_inserter(upper),
[](char c) { return std::toupper(c); });
std::transform(name.cbegin(), name.cend(), name.begin(),
[](char c) { return std::tolower(c); });
keys.push_back(upper);
keys.push_back(name);
auto const& obj = get<Object const>(config);
auto it = obj.find(upper);
if (it != obj.cend()) {
return OptionalArg<decltype(t)>(config, upper, dft);
} else {
return OptionalArg<decltype(t)>(config, name, dft);
}
};
// Common args
auto retry =
OptionalArg<Integer>(config, "dmlc_retry", static_cast<Integer::Int>(DefaultRetry()));
auto timeout = OptionalArg<Integer>(config, "dmlc_timeout_sec",
static_cast<Integer::Int>(DefaultTimeoutSec()));
auto task_id = get_param("dmlc_task_id", std::string{}, String{});
if (type == "rabit") {
auto host = get_param("dmlc_tracker_uri", std::string{}, String{});
auto port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
auto ptr =
new CommGroup{std::shared_ptr<RabitComm>{new RabitComm{ // NOLINT
host, static_cast<std::int32_t>(port), std::chrono::seconds{timeout},
static_cast<std::int32_t>(retry), task_id}},
std::shared_ptr<Coll>(new Coll{})}; // NOLINT
return ptr;
} else if (type == "federated") {
#if defined(XGBOOST_USE_FEDERATED)
auto ptr = new CommGroup{
std::make_shared<FederatedComm>(retry, std::chrono::seconds{timeout}, task_id, config),
std::make_shared<FederatedColl>()};
return ptr;
#endif // defined(XGBOOST_USE_FEDERATED)
} else {
LOG(FATAL) << "Invalid communicator type";
}
return nullptr;
}
std::unique_ptr<collective::CommGroup>& GlobalCommGroup() {
static std::unique_ptr<collective::CommGroup> sptr;
if (!sptr) {
Json config{Null{}};
sptr.reset(CommGroup::Create(config));
}
return sptr;
}
void GlobalCommGroupInit(Json config) {
auto& sptr = GlobalCommGroup();
sptr.reset(CommGroup::Create(std::move(config)));
}
void GlobalCommGroupFinalize() {
auto& sptr = GlobalCommGroup();
sptr.reset();
}
} // namespace xgboost::collective

View File

@@ -0,0 +1,53 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <memory> // for shared_ptr, unique_ptr
#include <string> // for string
#include <utility> // for move
#include "coll.h" // for Comm
#include "comm.h" // for Coll
#include "xgboost/collective/result.h" // for Result
#include "xgboost/collective/socket.h" // for GetHostName
namespace xgboost::collective {
/**
* @brief Communicator group used for double dispatching between communicators and
* collective implementations.
*/
class CommGroup {
std::shared_ptr<Comm> comm_;
mutable std::shared_ptr<Comm> gpu_comm_;
std::shared_ptr<Coll> backend_;
mutable std::shared_ptr<Coll> gpu_coll_; // lazy initialization
CommGroup(std::shared_ptr<Comm> comm, std::shared_ptr<Coll> coll)
: comm_{std::move(comm)}, backend_{std::move(coll)} {}
public:
CommGroup();
[[nodiscard]] auto World() const { return comm_->World(); }
[[nodiscard]] auto Rank() const { return comm_->Rank(); }
[[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); }
[[nodiscard]] static CommGroup* Create(Json config);
[[nodiscard]] std::shared_ptr<Coll> Backend(DeviceOrd device) const;
[[nodiscard]] Comm const& Ctx(Context const* ctx, DeviceOrd device) const;
[[nodiscard]] Result SignalError(Result const& res) { return comm_->SignalError(res); }
[[nodiscard]] Result ProcessorName(std::string* out) const {
auto rc = GetHostName(out);
return rc;
}
};
std::unique_ptr<collective::CommGroup>& GlobalCommGroup();
void GlobalCommGroupInit(Json config);
void GlobalCommGroupFinalize();
} // namespace xgboost::collective

View File

@@ -58,36 +58,35 @@ Result Tracker::WaitUntilReady() const {
RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr)
: sock_{std::move(sock)} {
auto host = addr.Addr();
std::int32_t rank{0};
rc_ = Success()
<< [&] { return proto::Magic{}.Verify(&sock_); }
<< [&] { return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_); };
if (!rc_.OK()) {
return;
}
std::string cmd;
sock_.Recv(&cmd);
auto jcmd = Json::Load(StringView{cmd});
cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
Json jcmd;
std::int32_t port{0};
if (cmd_ == proto::CMD::kStart) {
proto::Start start;
rc_ = start.TrackerHandle(jcmd, &world_, world, &port, &sock_, &eport_);
} else if (cmd_ == proto::CMD::kPrint) {
proto::Print print;
rc_ = print.TrackerHandle(jcmd, &msg_);
} else if (cmd_ == proto::CMD::kError) {
proto::ErrorCMD error;
rc_ = error.TrackerHandle(jcmd, &msg_, &code_);
}
if (!rc_.OK()) {
return;
}
info_ = proto::PeerInfo{host, port, rank};
rc_ = Success() << [&] { return proto::Magic{}.Verify(&sock_); } << [&] {
return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_);
} << [&] {
std::string cmd;
sock_.Recv(&cmd);
jcmd = Json::Load(StringView{cmd});
cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
return Success();
} << [&] {
if (cmd_ == proto::CMD::kStart) {
proto::Start start;
return start.TrackerHandle(jcmd, &world_, world, &port, &sock_, &eport_);
} else if (cmd_ == proto::CMD::kPrint) {
proto::Print print;
return print.TrackerHandle(jcmd, &msg_);
} else if (cmd_ == proto::CMD::kError) {
proto::ErrorCMD error;
return error.TrackerHandle(jcmd, &msg_, &code_);
}
return Success();
} << [&] {
auto host = addr.Addr();
info_ = proto::PeerInfo{host, port, rank};
return Success();
};
}
RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
@@ -137,15 +136,18 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
std::int32_t n_shutdown{0};
bool during_restart{false};
bool running{false};
std::vector<WorkerProxy> pending;
explicit State(std::int32_t world) : n_workers{world} {}
State(State const& that) = delete;
State& operator=(State&& that) = delete;
// modifiers
void Start(WorkerProxy&& worker) {
CHECK_LT(pending.size(), n_workers);
CHECK_LE(n_shutdown, n_workers);
CHECK(!running);
pending.emplace_back(std::forward<WorkerProxy>(worker));
@@ -155,6 +157,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
CHECK_GE(n_shutdown, 0);
CHECK_LT(n_shutdown, n_workers);
running = false;
++n_shutdown;
CHECK_LE(n_shutdown, n_workers);
@@ -163,21 +166,26 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
CHECK_LE(pending.size(), n_workers);
CHECK_LE(n_shutdown, n_workers);
running = false;
during_restart = true;
}
[[nodiscard]] bool Ready() const {
CHECK_LE(pending.size(), n_workers);
return static_cast<std::int32_t>(pending.size()) == n_workers;
}
void Bootstrap() {
CHECK_EQ(pending.size(), n_workers);
CHECK_LE(n_shutdown, n_workers);
running = true;
// A reset.
n_shutdown = 0;
during_restart = false;
pending.clear();
}
// observers
[[nodiscard]] bool Ready() const {
CHECK_LE(pending.size(), n_workers);
return static_cast<std::int32_t>(pending.size()) == n_workers;
}
[[nodiscard]] bool ShouldContinue() const {
CHECK_LE(pending.size(), n_workers);
CHECK_LE(n_shutdown, n_workers);
@@ -187,7 +195,31 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
}
};
return std::async(std::launch::async, [this] {
auto handle_error = [&](WorkerProxy const& worker) {
auto msg = worker.Msg();
auto code = worker.Code();
LOG(WARNING) << "Recieved error from [" << worker.Host() << ":" << worker.Rank() << "]: " << msg
<< " code:" << code;
auto host = worker.Host();
// We signal all workers for the error, if they haven't aborted already.
for (auto& w : worker_error_handles_) {
if (w.first == host) {
continue;
}
TCPSocket out;
// Connecting to the error port as a signal for exit.
//
// retry is set to 1, just let the worker timeout or error. Otherwise the
// tracker and the worker might be waiting for each other.
auto rc = Connect(w.first, w.second, 1, timeout_, &out);
if (!rc.OK()) {
return Fail("Failed to inform workers to stop.");
}
}
return Success();
};
return std::async(std::launch::async, [this, handle_error] {
State state{this->n_workers_};
while (state.ShouldContinue()) {
@@ -205,6 +237,16 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
}
switch (worker.Command()) {
case proto::CMD::kStart: {
if (state.running) {
// Something went wrong with one of the workers. It got disconnected without
// notice.
state.Error();
rc = handle_error(worker);
if (!rc.OK()) {
return Fail("Failed to handle abort.", std::move(rc));
}
}
state.Start(std::move(worker));
if (state.Ready()) {
rc = this->Bootstrap(&state.pending);
@@ -216,36 +258,20 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
continue;
}
case proto::CMD::kShutdown: {
if (state.during_restart) {
// The worker can still send shutdown after call to `std::exit`.
continue;
}
state.Shutdown();
continue;
}
case proto::CMD::kError: {
if (state.during_restart) {
// Ignore further errors.
continue;
}
state.Error();
auto msg = worker.Msg();
auto code = worker.Code();
LOG(WARNING) << "Recieved error from [" << worker.Host() << ":" << worker.Rank()
<< "]: " << msg << " code:" << code;
auto host = worker.Host();
// We signal all workers for the error, if they haven't aborted already.
for (auto& w : worker_error_handles_) {
if (w.first == host) {
continue;
}
TCPSocket out;
// retry is set to 1, just let the worker timeout or error. Otherwise the
// tracker and the worker might be waiting for each other.
auto rc = Connect(w.first, w.second, 1, timeout_, &out);
// send signal to stop the worker.
proto::ShutdownCMD shutdown;
rc = shutdown.Send(&out);
if (!rc.OK()) {
return Fail("Failed to inform workers to stop.");
}
}
rc = handle_error(worker);
continue;
}
case proto::CMD::kPrint: {