[coll] Implement shutdown for tracker and comm. (#10208)

- Force shutdown the tracker.
- Implement shutdown notice for error handling thread in comm.
This commit is contained in:
Jiaming Yuan
2024-04-20 04:08:17 +08:00
committed by GitHub
parent 8fb05c8c95
commit 3fbb221fec
24 changed files with 553 additions and 199 deletions

View File

@@ -5,9 +5,11 @@
#include <future> // for future
#include <memory> // for unique_ptr
#include <string> // for string
#include <thread> // for sleep_for
#include <type_traits> // for is_same_v, remove_pointer_t
#include <utility> // for pair
#include "../collective/comm.h" // for DefaultTimeoutSec
#include "../collective/tracker.h" // for RabitTracker
#include "../common/timer.h" // for Timer
#include "c_api_error.h" // for API_BEGIN
@@ -26,7 +28,7 @@ using namespace xgboost; // NOLINT
namespace {
using TrackerHandleT =
std::pair<std::unique_ptr<collective::Tracker>, std::shared_future<collective::Result>>;
std::pair<std::shared_ptr<collective::Tracker>, std::shared_future<collective::Result>>;
TrackerHandleT *GetTrackerHandle(TrackerHandle handle) {
xgboost_CHECK_C_ARG_PTR(handle);
@@ -41,12 +43,14 @@ struct CollAPIEntry {
using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
constexpr std::int64_t kDft{60};
constexpr std::int64_t kDft{collective::DefaultTimeoutSec()};
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
common::Timer timer;
timer.Start();
auto ref = ptr->first;  // hold a reference so that XGTrackerFree doesn't delete the tracker while we wait.
auto fut = ptr->second;
while (fut.valid()) {
auto res = fut.wait_for(wait_for);
@@ -72,15 +76,15 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle) {
Json jconfig = Json::Load(config);
auto type = RequiredArg<String>(jconfig, "dmlc_communicator", __func__);
std::unique_ptr<collective::Tracker> tptr;
std::shared_ptr<collective::Tracker> tptr;
if (type == "federated") {
#if defined(XGBOOST_USE_FEDERATED)
tptr = std::make_unique<collective::FederatedTracker>(jconfig);
tptr = std::make_shared<collective::FederatedTracker>(jconfig);
#else
LOG(FATAL) << error::NoFederated();
#endif // defined(XGBOOST_USE_FEDERATED)
} else if (type == "rabit") {
tptr = std::make_unique<collective::RabitTracker>(jconfig);
tptr = std::make_shared<collective::RabitTracker>(jconfig);
} else {
LOG(FATAL) << "Unknown communicator:" << type;
}
@@ -103,7 +107,7 @@ XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args) {
API_END();
}
XGB_DLL int XGTrackerRun(TrackerHandle handle) {
XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *) {
API_BEGIN();
auto *ptr = GetTrackerHandle(handle);
CHECK(!ptr->second.valid()) << "Tracker is already running.";
@@ -111,13 +115,14 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle) {
API_END();
}
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config) {
API_BEGIN();
auto *ptr = GetTrackerHandle(handle);
xgboost_CHECK_C_ARG_PTR(config);
auto jconfig = Json::Load(StringView{config});
// Internally, 0 indicates no timeout, which is the default since we don't want to
// interrupt the model training.
xgboost_CHECK_C_ARG_PTR(config);
auto timeout = OptionalArg<Integer>(jconfig, "timeout", std::int64_t{0});
WaitImpl(ptr, std::chrono::seconds{timeout});
API_END();
@@ -125,8 +130,24 @@ XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
XGB_DLL int XGTrackerFree(TrackerHandle handle) {
API_BEGIN();
using namespace std::chrono_literals; // NOLINT
auto *ptr = GetTrackerHandle(handle);
ptr->first->Stop();
// The wait is not necessary since we just called stop, just reusing the function to do
// any potential cleanups.
WaitImpl(ptr, ptr->first->Timeout());
common::Timer timer;
timer.Start();
// Make sure no one else is waiting on the tracker.
while (!ptr->first.unique()) {
auto ela = timer.Duration().count();
if (ela > ptr->first->Timeout().count()) {
LOG(WARNING) << "Time out " << ptr->first->Timeout().count()
<< " seconds reached for TrackerFree, killing the tracker.";
break;
}
std::this_thread::sleep_for(64ms);
}
delete ptr;
API_END();
}

View File

@@ -38,6 +38,10 @@ bool constexpr IsFloatingPointV() {
auto redop_fn = [](auto lhs, auto out, auto elem_op) {
auto p_lhs = lhs.data();
auto p_out = out.data();
#if defined(__GNUC__) || defined(__clang__)
// For the sum op, one can verify the simd by: addps %xmm15, %xmm14
#pragma omp simd
#endif
for (std::size_t i = 0; i < lhs.size(); ++i) {
p_out[i] = elem_op(p_lhs[i], p_out[i]);
}

View File

@@ -5,9 +5,11 @@
#include <algorithm> // for copy
#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <cstdlib> // for exit
#include <memory> // for shared_ptr
#include <string> // for string
#include <thread> // for thread
#include <utility> // for move, forward
#if !defined(XGBOOST_USE_NCCL)
#include "../common/common.h" // for AssertNCCLSupport
@@ -184,13 +186,30 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
return Success();
}
RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t retry, std::string task_id, StringView nccl_path)
: HostComm{std::move(host), port, timeout, retry, std::move(task_id)},
namespace {
// Build the console message printed once the communicator is bootstrapped.
// Falls back to the bare rank when no task ID was configured.
std::string InitLog(std::string task_id, std::int32_t rank) {
  auto const rank_str = std::to_string(rank);
  return task_id.empty() ? ("Rank " + rank_str)
                         : ("Task " + task_id + " got rank " + rank_str);
}
} // namespace
RabitComm::RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
StringView nccl_path)
: HostComm{tracker_host, tracker_port, timeout, retry, std::move(task_id)},
nccl_path_{std::move(nccl_path)} {
if (this->TrackerInfo().host.empty()) {
// Not in a distributed environment.
LOG(CONSOLE) << InitLog(task_id_, rank_);
return;
}
loop_.reset(new Loop{std::chrono::seconds{timeout_}}); // NOLINT
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
if (!rc.OK()) {
this->ResetState();
SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc)));
}
}
@@ -217,20 +236,54 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
// Start command
TCPSocket listener = TCPSocket::Create(tracker.Domain());
std::int32_t lport = listener.BindHost();
listener.Listen();
std::int32_t lport{0};
rc = std::move(rc) << [&] {
return listener.BindHost(&lport);
} << [&] {
return listener.Listen();
};
if (!rc.OK()) {
return rc;
}
// create worker for listening to error notice.
auto domain = tracker.Domain();
std::shared_ptr<TCPSocket> error_sock{TCPSocket::CreatePtr(domain)};
auto eport = error_sock->BindHost();
error_sock->Listen();
std::int32_t eport{0};
rc = std::move(rc) << [&] {
return error_sock->BindHost(&eport);
} << [&] {
return error_sock->Listen();
};
if (!rc.OK()) {
return rc;
}
error_port_ = eport;
error_worker_ = std::thread{[error_sock = std::move(error_sock)] {
auto conn = error_sock->Accept();
TCPSocket conn;
SockAddress addr;
auto rc = error_sock->Accept(&conn, &addr);
// On Linux, a shutdown causes an invalid argument error;
if (rc.Code() == std::errc::invalid_argument) {
return;
}
// On Windows, accept returns a closed socket after finalize.
if (conn.IsClosed()) {
return;
}
// The error signal is from the tracker, while shutdown signal is from the shutdown method
// of the RabitComm class (this).
bool is_error{false};
rc = proto::Error{}.RecvSignal(&conn, &is_error);
if (!rc.OK()) {
LOG(WARNING) << rc.Report();
return;
}
if (!is_error) {
return; // shutdown
}
LOG(WARNING) << "Another worker is running into error.";
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
// exit is nicer than abort as the former performs cleanups.
@@ -239,6 +292,9 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
LOG(FATAL) << "abort";
#endif
}};
// The worker thread is detached here to avoid the need to handle it later during
// destruction. For C++, if a thread is not joined or detached, it will segfault during
// destruction.
error_worker_.detach();
proto::Start start;
@@ -251,7 +307,10 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
// get ring neighbors
std::string snext;
tracker.Recv(&snext);
rc = tracker.Recv(&snext);
if (!rc.OK()) {
return Fail("Failed to receive the rank for the next worker.", std::move(rc));
}
auto jnext = Json::Load(StringView{snext});
proto::PeerInfo ninfo{jnext};
@@ -268,14 +327,21 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
CHECK(this->channels_.empty());
for (auto& w : workers) {
if (w) {
rc = std::move(rc) << [&] { return w->SetNoDelay(); } << [&] { return w->NonBlocking(true); }
<< [&] { return w->SetKeepAlive(); };
rc = std::move(rc) << [&] {
return w->SetNoDelay();
} << [&] {
return w->NonBlocking(true);
} << [&] {
return w->SetKeepAlive();
};
}
if (!rc.OK()) {
return rc;
}
this->channels_.emplace_back(std::make_shared<Channel>(*this, w));
}
LOG(CONSOLE) << InitLog(task_id_, rank_);
return rc;
}
@@ -283,6 +349,8 @@ RabitComm::~RabitComm() noexcept(false) {
if (!this->IsDistributed()) {
return;
}
LOG(WARNING) << "The communicator is being destroyed without a call to shutdown first. This can "
"lead to undefined behaviour.";
auto rc = this->Shutdown();
if (!rc.OK()) {
LOG(WARNING) << rc.Report();
@@ -293,30 +361,49 @@ RabitComm::~RabitComm() noexcept(false) {
if (!this->IsDistributed()) {
return Success();
}
// Tell the tracker that this worker is shutting down.
TCPSocket tracker;
// Tell the error handling thread that we are shutting down.
TCPSocket err_client;
return Success() << [&] {
return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World());
} << [&] {
return this->Block();
} << [&] {
Json jcmd{Object{}};
jcmd["cmd"] = Integer{static_cast<std::int32_t>(proto::CMD::kShutdown)};
auto scmd = Json::Dump(jcmd);
auto n_bytes = tracker.Send(scmd);
if (n_bytes != scmd.size()) {
return Fail("Faled to send cmd.");
}
this->ResetState();
return Success();
return proto::ShutdownCMD{}.Send(&tracker);
} << [&] {
this->channels_.clear();
return Success();
} << [&] {
// Use tracker address to determine whether we want to use IPv6.
auto taddr = MakeSockAddress(xgboost::StringView{this->tracker_.host}, this->tracker_.port);
// Shutdown the error handling thread. We signal the thread through socket,
// alternatively, we can get the native handle and use pthread_cancel. But using a
// socket seems to be clearer as we know what's happening.
auto const& addr = taddr.IsV4() ? SockAddrV4::Loopback().Addr() : SockAddrV6::Loopback().Addr();
// We use hardcoded 10 seconds and 1 retry here since we are just connecting to a
// local socket. For a normal OS, this should be enough time to schedule the
// connection.
auto rc = Connect(StringView{addr}, this->error_port_, 1,
std::min(std::chrono::seconds{10}, timeout_), &err_client);
this->ResetState();
if (!rc.OK()) {
return Fail("Failed to connect to the error socket.", std::move(rc));
}
return rc;
} << [&] {
// We put error thread shutdown at the end so that we have a better chance to finish
// the previous more important steps.
return proto::Error{}.SignalShutdown(&err_client);
};
}
[[nodiscard]] Result RabitComm::LogTracker(std::string msg) const {
if (!this->IsDistributed()) {
LOG(CONSOLE) << msg;
return Success();
}
TCPSocket out;
proto::Print print;
return Success() << [&] { return this->ConnectTracker(&out); }
@@ -324,8 +411,11 @@ RabitComm::~RabitComm() noexcept(false) {
}
[[nodiscard]] Result RabitComm::SignalError(Result const& res) {
TCPSocket out;
return Success() << [&] { return this->ConnectTracker(&out); }
<< [&] { return proto::ErrorCMD{}.WorkerSend(&out, res); };
TCPSocket tracker;
return Success() << [&] {
return this->ConnectTracker(&tracker);
} << [&] {
return proto::ErrorCMD{}.WorkerSend(&tracker, res);
};
}
} // namespace xgboost::collective

View File

@@ -1,10 +1,10 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
#include <chrono> // for seconds
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <cstdint> // for int32_t, int64_t
#include <memory> // for shared_ptr
#include <string> // for string
#include <thread> // for thread
@@ -20,7 +20,7 @@
namespace xgboost::collective {
inline constexpr std::int32_t DefaultTimeoutSec() { return 300; } // 5min
inline constexpr std::int64_t DefaultTimeoutSec() { return 300; } // 5min
inline constexpr std::int32_t DefaultRetry() { return 3; }
// indexing into the ring
@@ -51,7 +51,10 @@ class Comm : public std::enable_shared_from_this<Comm> {
proto::PeerInfo tracker_;
SockDomain domain_{SockDomain::kV4};
std::thread error_worker_;
std::int32_t error_port_;
std::string task_id_;
std::vector<std::shared_ptr<Channel>> channels_;
std::shared_ptr<Loop> loop_{nullptr}; // fixme: require federated comm to have a timeout
@@ -59,6 +62,13 @@ class Comm : public std::enable_shared_from_this<Comm> {
void ResetState() {
this->world_ = -1;
this->rank_ = 0;
this->timeout_ = std::chrono::seconds{DefaultTimeoutSec()};
tracker_ = proto::PeerInfo{};
this->task_id_.clear();
channels_.clear();
loop_.reset();
}
public:
@@ -79,9 +89,9 @@ class Comm : public std::enable_shared_from_this<Comm> {
[[nodiscard]] auto Retry() const { return retry_; }
[[nodiscard]] auto TaskID() const { return task_id_; }
[[nodiscard]] auto Rank() const { return rank_; }
[[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; }
[[nodiscard]] bool IsDistributed() const { return world_ != -1; }
[[nodiscard]] auto Rank() const noexcept { return rank_; }
[[nodiscard]] auto World() const noexcept { return IsDistributed() ? world_ : 1; }
[[nodiscard]] bool IsDistributed() const noexcept { return world_ != -1; }
void Submit(Loop::Op op) const {
CHECK(loop_);
loop_->Submit(op);
@@ -120,20 +130,20 @@ class RabitComm : public HostComm {
[[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
std::string task_id);
[[nodiscard]] Result Shutdown() final;
public:
// bootstrapping construction.
RabitComm() = default;
// ctor for testing where environment is known.
RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t retry, std::string task_id, StringView nccl_path);
RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
StringView nccl_path);
~RabitComm() noexcept(false) override;
[[nodiscard]] bool IsFederated() const override { return false; }
[[nodiscard]] Result LogTracker(std::string msg) const override;
[[nodiscard]] Result SignalError(Result const&) override;
[[nodiscard]] Result Shutdown() final;
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
};

View File

@@ -64,6 +64,9 @@ CommGroup::CommGroup()
auto const& obj = get<Object const>(config);
auto it = obj.find(upper);
if (it != obj.cend() && obj.find(name) != obj.cend()) {
LOG(FATAL) << "Duplicated parameter:" << name;
}
if (it != obj.cend()) {
return OptionalArg<decltype(t)>(config, upper, dft);
} else {
@@ -77,14 +80,14 @@ CommGroup::CommGroup()
auto task_id = get_param("dmlc_task_id", std::string{}, String{});
if (type == "rabit") {
auto host = get_param("dmlc_tracker_uri", std::string{}, String{});
auto port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
auto tracker_host = get_param("dmlc_tracker_uri", std::string{}, String{});
auto tracker_port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
auto nccl = get_param("dmlc_nccl_path", std::string{DefaultNcclName()}, String{});
auto ptr =
new CommGroup{std::shared_ptr<RabitComm>{new RabitComm{ // NOLINT
host, static_cast<std::int32_t>(port), std::chrono::seconds{timeout},
static_cast<std::int32_t>(retry), task_id, nccl}},
std::shared_ptr<Coll>(new Coll{})}; // NOLINT
auto ptr = new CommGroup{
std::shared_ptr<RabitComm>{new RabitComm{ // NOLINT
tracker_host, static_cast<std::int32_t>(tracker_port), std::chrono::seconds{timeout},
static_cast<std::int32_t>(retry), task_id, nccl}},
std::shared_ptr<Coll>(new Coll{})}; // NOLINT
return ptr;
} else if (type == "federated") {
#if defined(XGBOOST_USE_FEDERATED)

View File

@@ -30,9 +30,9 @@ class CommGroup {
public:
CommGroup();
[[nodiscard]] auto World() const { return comm_->World(); }
[[nodiscard]] auto Rank() const { return comm_->Rank(); }
[[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); }
[[nodiscard]] auto World() const noexcept { return comm_->World(); }
[[nodiscard]] auto Rank() const noexcept { return comm_->Rank(); }
[[nodiscard]] bool IsDistributed() const noexcept { return comm_->IsDistributed(); }
[[nodiscard]] Result Finalize() const {
return Success() << [this] {

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int32_t
@@ -58,6 +58,7 @@ struct Magic {
}
};
// Basic commands for communication between workers and the tracker.
enum class CMD : std::int32_t {
kInvalid = 0,
kStart = 1,
@@ -84,7 +85,10 @@ struct Connect {
[[nodiscard]] Result TrackerRecv(TCPSocket* sock, std::int32_t* world, std::int32_t* rank,
std::string* task_id) const {
std::string init;
sock->Recv(&init);
auto rc = sock->Recv(&init);
if (!rc.OK()) {
return Fail("Connect protocol failed.", std::move(rc));
}
auto jinit = Json::Load(StringView{init});
*world = get<Integer const>(jinit["world_size"]);
*rank = get<Integer const>(jinit["rank"]);
@@ -122,9 +126,9 @@ class Start {
}
[[nodiscard]] Result WorkerRecv(TCPSocket* tracker, std::int32_t* p_world) const {
std::string scmd;
auto n_bytes = tracker->Recv(&scmd);
if (n_bytes <= 0) {
return Fail("Failed to recv init command from tracker.");
auto rc = tracker->Recv(&scmd);
if (!rc.OK()) {
return Fail("Failed to recv init command from tracker.", std::move(rc));
}
auto jcmd = Json::Load(scmd);
auto world = get<Integer const>(jcmd["world_size"]);
@@ -132,7 +136,7 @@ class Start {
return Fail("Invalid world size.");
}
*p_world = world;
return Success();
return rc;
}
[[nodiscard]] Result TrackerHandle(Json jcmd, std::int32_t* recv_world, std::int32_t world,
std::int32_t* p_port, TCPSocket* p_sock,
@@ -150,6 +154,7 @@ class Start {
}
};
// Protocol for communicating with the tracker for printing message.
struct Print {
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::string msg) const {
Json jcmd{Object{}};
@@ -172,6 +177,7 @@ struct Print {
}
};
// Protocol for communicating with the tracker during error.
struct ErrorCMD {
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, Result const& res) const {
auto msg = res.Report();
@@ -199,6 +205,7 @@ struct ErrorCMD {
}
};
// Protocol for communicating with the tracker during shutdown.
struct ShutdownCMD {
[[nodiscard]] Result Send(TCPSocket* peer) const {
Json jcmd{Object{}};
@@ -211,4 +218,40 @@ struct ShutdownCMD {
return Success();
}
};
// Protocol for communicating with the local error handler during error or shutdown. Only
// one protocol that doesn't have the tracker involved.
struct Error {
  // Wire values exchanged over the local error socket.
  constexpr static std::int32_t ShutdownSignal() { return 0; }
  constexpr static std::int32_t ErrorSignal() { return -1; }

  /**
   * @brief Tell a worker's error-handling thread that another worker has failed.
   *
   * @param worker Connection to the remote worker's error socket.
   */
  [[nodiscard]] Result SignalError(TCPSocket* worker) const {
    std::int32_t err{ErrorSignal()};
    auto n_sent = worker->SendAll(&err, sizeof(err));
    if (n_sent == sizeof(err)) {
      return Success();
    }
    return Fail("Failed to send error signal");
  }
  // self is localhost, we are sending the signal to the error handling thread for it to
  // close.
  [[nodiscard]] Result SignalShutdown(TCPSocket* self) const {
    std::int32_t err{ShutdownSignal()};
    auto n_sent = self->SendAll(&err, sizeof(err));
    if (n_sent == sizeof(err)) {
      return Success();
    }
    return Fail("Failed to send shutdown signal");
  }
  // Receive a signal, either for error or for shutdown.
  [[nodiscard]] Result RecvSignal(TCPSocket* peer, bool* p_is_error) const {
    std::int32_t err{ShutdownSignal()};
    auto n_recv = peer->RecvAll(&err, sizeof(err));
    if (n_recv == sizeof(err)) {
      // Compare against ErrorSignal() (-1). The previous check `err == 1` matched
      // neither ErrorSignal() nor ShutdownSignal(), so a genuine error signal was
      // misreported as a shutdown and the receiving worker never aborted.
      *p_is_error = err == ErrorSignal();
      return Success();
    }
    return Fail("Failed to receive error signal.");
  }
};
} // namespace xgboost::collective::proto

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2022-2023 by XGBoost Contributors
* Copyright 2022-2024, XGBoost Contributors
*/
#include "xgboost/collective/socket.h"
@@ -8,7 +8,8 @@
#include <cstdint> // std::int32_t
#include <cstring> // std::memcpy, std::memset
#include <filesystem> // for path
#include <system_error> // std::error_code, std::system_category
#include <system_error> // for error_code, system_category
#include <thread> // for sleep_for
#include "rabit/internal/socket.h" // for PollHelper
#include "xgboost/collective/result.h" // for Result
@@ -65,14 +66,18 @@ std::size_t TCPSocket::Send(StringView str) {
return bytes;
}
std::size_t TCPSocket::Recv(std::string *p_str) {
[[nodiscard]] Result TCPSocket::Recv(std::string *p_str) {
CHECK(!this->IsClosed());
std::int32_t len;
CHECK_EQ(this->RecvAll(&len, sizeof(len)), sizeof(len)) << "Failed to recv string length.";
if (this->RecvAll(&len, sizeof(len)) != sizeof(len)) {
return Fail("Failed to recv string length.");
}
p_str->resize(len);
auto bytes = this->RecvAll(&(*p_str)[0], len);
CHECK_EQ(bytes, len) << "Failed to recv string.";
return bytes;
if (static_cast<decltype(len)>(bytes) != len) {
return Fail("Failed to recv string.");
}
return Success();
}
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
@@ -110,11 +115,7 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
if (attempt > 0) {
LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time.";
#if defined(_MSC_VER) || defined(__MINGW32__)
Sleep(attempt << 1);
#else
sleep(attempt << 1);
#endif
std::this_thread::sleep_for(std::chrono::seconds{attempt << 1});
}
auto rc = connect(conn.Handle(), addr_handle, addr_len);
@@ -158,8 +159,8 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
std::stringstream ss;
ss << "Failed to connect to " << host << ":" << port;
conn.Close();
return Fail(ss.str(), std::move(last_error));
auto close_rc = conn.Close();
return Fail(ss.str(), std::move(close_rc) + std::move(last_error));
}
[[nodiscard]] Result GetHostName(std::string *p_out) {

View File

@@ -1,6 +1,7 @@
/**
* Copyright 2023-2024, XGBoost Contributors
*/
#include "rabit/internal/socket.h"
#if defined(__unix__) || defined(__APPLE__)
#include <netdb.h> // gethostbyname
#include <sys/socket.h> // socket, AF_INET6, AF_INET, connect, getsockname
@@ -70,10 +71,13 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_);
} << [&] {
std::string cmd;
sock_.Recv(&cmd);
auto rc = sock_.Recv(&cmd);
if (!rc.OK()) {
return rc;
}
jcmd = Json::Load(StringView{cmd});
cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
return Success();
return rc;
} << [&] {
if (cmd_ == proto::CMD::kStart) {
proto::Start start;
@@ -100,14 +104,18 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
std::string self;
auto rc = collective::GetHostAddress(&self);
host_ = OptionalArg<String>(config, "host", self);
auto rc = Success() << [&] {
return collective::GetHostAddress(&self);
} << [&] {
host_ = OptionalArg<String>(config, "host", self);
auto addr = MakeSockAddress(xgboost::StringView{host_}, 0);
listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6);
rc = listener_.Bind(host_, &this->port_);
auto addr = MakeSockAddress(xgboost::StringView{host_}, 0);
listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6);
return listener_.Bind(host_, &this->port_);
} << [&] {
return listener_.Listen();
};
SafeColl(rc);
listener_.Listen();
}
Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
@@ -220,9 +228,13 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
//
// retry is set to 1, just let the worker timeout or error. Otherwise the
// tracker and the worker might be waiting for each other.
auto rc = Connect(w.first, w.second, 1, timeout_, &out);
auto rc = Success() << [&] {
return Connect(w.first, w.second, 1, timeout_, &out);
} << [&] {
return proto::Error{}.SignalError(&out);
};
if (!rc.OK()) {
return Fail("Failed to inform workers to stop.");
return Fail("Failed to inform worker:" + w.first + " for error.", std::move(rc));
}
}
return Success();
@@ -231,13 +243,37 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
return std::async(std::launch::async, [this, handle_error] {
State state{this->n_workers_};
auto select_accept = [&](TCPSocket* sock, auto* addr) {
// accept with poll so that we can enable timeout and interruption.
rabit::utils::PollHelper poll;
auto rc = Success() << [&] {
std::lock_guard lock{listener_mu_};
return listener_.NonBlocking(true);
} << [&] {
std::lock_guard lock{listener_mu_};
poll.WatchRead(listener_);
if (state.running) {
// Don't timeout if the communicator group is up and running.
return poll.Poll(std::chrono::seconds{-1});
} else {
// Have timeout for workers to bootstrap.
return poll.Poll(timeout_);
}
} << [&] {
// this->Stop() closes the socket with a lock. Therefore, when the accept returns
// due to shutdown, the state is still valid (closed).
return listener_.Accept(sock, addr);
};
return rc;
};
while (state.ShouldContinue()) {
TCPSocket sock;
SockAddress addr;
this->ready_ = true;
auto rc = listener_.Accept(&sock, &addr);
auto rc = select_accept(&sock, &addr);
if (!rc.OK()) {
return Fail("Failed to accept connection.", std::move(rc));
return Fail("Failed to accept connection.", this->Stop() + std::move(rc));
}
auto worker = WorkerProxy{n_workers_, std::move(sock), std::move(addr)};
@@ -252,7 +288,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
state.Error();
rc = handle_error(worker);
if (!rc.OK()) {
return Fail("Failed to handle abort.", std::move(rc));
return Fail("Failed to handle abort.", this->Stop() + std::move(rc));
}
}
@@ -262,7 +298,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
state.Bootstrap();
}
if (!rc.OK()) {
return rc;
return this->Stop() + std::move(rc);
}
continue;
}
@@ -289,12 +325,11 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
}
case proto::CMD::kInvalid:
default: {
return Fail("Invalid command received.");
return Fail("Invalid command received.", this->Stop());
}
}
}
ready_ = false;
return Success();
return this->Stop();
});
}
@@ -303,11 +338,30 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
SafeColl(rc);
Json args{Object{}};
args["DMLC_TRACKER_URI"] = String{host_};
args["DMLC_TRACKER_PORT"] = this->Port();
args["dmlc_tracker_uri"] = String{host_};
args["dmlc_tracker_port"] = this->Port();
return args;
}
// Force-stop the tracker by shutting down and closing the listening socket. This is
// what unblocks an accept loop stuck waiting for workers; the tracker cannot be
// restarted afterwards. Idempotent: safe to call when not ready or already closed.
[[nodiscard]] Result RabitTracker::Stop() {
  // A tracker that never became ready has nothing to tear down.
  if (!this->Ready()) {
    return Success();
  }
  // Clear the flag first so the accept loop stops treating the listener as live.
  ready_ = false;
  // The accept loop touches the listener from another thread; the mutex prevents a
  // race between shutting it down here and polling/accepting there.
  std::lock_guard lock{listener_mu_};
  if (this->listener_.IsClosed()) {
    return Success();  // Already stopped by an earlier call.
  }
  return Success() << [&] {
    // This should have the effect of stopping the `accept` call.
    return this->listener_.Shutdown();
  } << [&] {
    return listener_.Close();
  };
}
[[nodiscard]] Result GetHostAddress(std::string* out) {
auto rc = GetHostName(out);
if (!rc.OK()) {

View File

@@ -36,15 +36,18 @@ namespace xgboost::collective {
* signal an error to the tracker and the tracker will notify other workers.
*/
class Tracker {
public:
enum class SortBy : std::int8_t {
kHost = 0,
kTask = 1,
};
protected:
// How to sort the workers, either by host name or by task ID. When using a multi-GPU
// setting, multiple workers can occupy the same host, in which case one should sort
// workers by task. Due to compatibility reason, the task ID is not always available, so
// we use host as the default.
enum class SortBy : std::int8_t {
kHost = 0,
kTask = 1,
} sortby_;
SortBy sortby_;
protected:
std::int32_t n_workers_{0};
@@ -54,10 +57,7 @@ class Tracker {
public:
explicit Tracker(Json const& config);
Tracker(std::int32_t n_worders, std::int32_t port, std::chrono::seconds timeout)
: n_workers_{n_worders}, port_{port}, timeout_{timeout} {}
virtual ~Tracker() noexcept(false){}; // NOLINT
virtual ~Tracker() = default;
[[nodiscard]] Result WaitUntilReady() const;
@@ -69,6 +69,11 @@ class Tracker {
* @brief Flag to indicate whether the server is running.
*/
[[nodiscard]] bool Ready() const { return ready_; }
/**
* @brief Shutdown the tracker, cannot be restarted again. Useful when the tracker hangs while
* calling accept.
*/
virtual Result Stop() { return Success(); }
};
class RabitTracker : public Tracker {
@@ -127,28 +132,22 @@ class RabitTracker : public Tracker {
// record for how to reach out to workers if error happens.
std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
// listening socket for incoming workers.
//
// At the moment, the listener calls accept without first polling. We can add an
// additional unix domain socket to allow cancelling the accept.
TCPSocket listener_;
// mutex for protecting the listener, used to prevent race when it's listening while
// another thread tries to shut it down.
std::mutex listener_mu_;
Result Bootstrap(std::vector<WorkerProxy>* p_workers);
public:
explicit RabitTracker(StringView host, std::int32_t n_worders, std::int32_t port,
std::chrono::seconds timeout)
: Tracker{n_worders, port, timeout}, host_{host.c_str(), host.size()} {
listener_ = TCPSocket::Create(SockDomain::kV4);
auto rc = listener_.Bind(host, &this->port_);
CHECK(rc.OK()) << rc.Report();
listener_.Listen();
}
explicit RabitTracker(Json const& config);
~RabitTracker() noexcept(false) override = default;
~RabitTracker() override = default;
std::future<Result> Run() override;
[[nodiscard]] Json WorkerArgs() const override;
// Stop the tracker without waiting. This is to prevent the tracker from hanging when
// one of the workers fails to start.
[[nodiscard]] Result Stop() override;
};
// Prob the public IP address of the host, need a better method.