Jiaming Yuan a5a58102e5
Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, parts of which have already been merged into XGBoost. The new implementation features:
- Federated learning for both CPU and GPU.
- NCCL support.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
2024-05-20 11:56:23 +08:00
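
The code below composes every fallible step with Result and operator<<: each lambda runs only if all preceding steps succeeded, and the first failure short-circuits the rest of the chain while preserving the error context. A minimal sketch of the idiom, using a simplified stand-in Result (the real type carries chained error messages and error codes):

#include <iostream>
#include <string>
#include <utility>

// Simplified stand-in for the Result type used in comm.cc below.
struct Result {
  bool ok{true};
  std::string msg;
};

Result Success() { return {}; }
Result Fail(std::string msg) { return {false, std::move(msg)}; }

// Run `next` only when the chain has not failed yet; otherwise forward the error.
template <typename Fn>
Result operator<<(Result&& prev, Fn&& next) {
  if (!prev.ok) {
    return std::move(prev);
  }
  return next();
}

int main() {
  auto rc = Success() << [] { return Success(); }
                      << [] { return Fail("step 2 failed"); }
                      << [] { return Success(); };  // Never runs.
  std::cout << (rc.ok ? "ok" : rc.msg) << std::endl;  // Prints "step 2 failed".
}

This keeps the multi-step bootstrap code linear while still propagating precise failure context, as seen in ConnectTrackerImpl below.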


/**
* Copyright 2023-2024, XGBoost Contributors
*/
#include "comm.h"
#include <algorithm> // for copy
#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <cstdlib> // for exit
#include <memory> // for shared_ptr
#include <string> // for string
#include <thread> // for thread
#include <utility> // for move, forward
#if !defined(XGBOOST_USE_NCCL)
#include "../common/common.h" // for AssertNCCLSupport
#endif // !defined(XGBOOST_USE_NCCL)
#include "allgather.h" // for RingAllgather
#include "protocol.h" // for kMagic
#include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE
#include "xgboost/collective/socket.h" // for TCPSocket
#include "xgboost/json.h" // for Json, Object
#include "xgboost/string_view.h" // for StringView
namespace xgboost::collective {
Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t retry, std::string task_id)
: timeout_{timeout}, retry_{retry}, tracker_{host, port, -1}, task_id_{std::move(task_id)} {}
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
std::string const& task_id, TCPSocket* out, std::int32_t rank,
std::int32_t world) {
// Get information from the tracker
CHECK(!info.host.empty());
TCPSocket& tracker = *out;
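// Each step below runs only if the previous one succeeded; the first failing
// Result short-circuits the rest of the chain and is returned to the caller.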
return Success() << [&] {
auto rc = Connect(info.host, info.port, retry, timeout, out);
if (rc.OK()) {
return rc;
} else {
return Fail("Failed to connect to the tracker.", std::move(rc));
}
} << [&] {
return tracker.NonBlocking(false);
} << [&] {
return tracker.RecvTimeout(timeout);
} << [&] {
return proto::Magic{}.Verify(&tracker);
} << [&] {
return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id);
} << [&] {
LOG(INFO) << "Task " << task_id << " connected to the tracker";
return Success();
};
}
[[nodiscard]] Result Comm::ConnectTracker(TCPSocket* out) const {
return ConnectTrackerImpl(this->TrackerInfo(), this->Timeout(), this->retry_, this->task_id_, out,
this->Rank(), this->World());
}
[[nodiscard]] Result ConnectWorkers(Comm const& comm, TCPSocket* listener, std::int32_t lport,
proto::PeerInfo ninfo, std::chrono::seconds timeout,
std::int32_t retry,
std::vector<std::shared_ptr<TCPSocket>>* out_workers) {
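// Ring bootstrap: actively connect to the next worker in the ring, accept a
// connection from the previous one, then use ring allgather to learn every
// peer's address before building the full mesh.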
auto next = std::make_shared<TCPSocket>();
auto prev = std::make_shared<TCPSocket>();
auto rc = Success() << [&] {
auto rc = Connect(ninfo.host, ninfo.port, retry, timeout, next.get());
if (!rc.OK()) {
return Fail("Bootstrap failed to connect to ring next.", std::move(rc));
}
return rc;
} << [&] {
return next->NonBlocking(true);
} << [&] {
SockAddress addr;
return listener->Accept(prev.get(), &addr);
} << [&] {
return prev->NonBlocking(true);
};
if (!rc.OK()) {
return rc;
}
// exchange host name and port
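// One fixed-width HOST_NAME_MAX slot per worker, so the ring allgather can
// exchange host names without variable-length framing.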
std::vector<std::int8_t> buffer(HOST_NAME_MAX * comm.World(), 0);
auto s_buffer = common::Span{buffer.data(), buffer.size()};
auto next_host = s_buffer.subspan(HOST_NAME_MAX * comm.Rank(), HOST_NAME_MAX);
if (next_host.size() < ninfo.host.size()) {
return Fail("Got an invalid host name.");
}
std::copy(ninfo.host.cbegin(), ninfo.host.cend(), next_host.begin());
auto prev_ch = std::make_shared<Channel>(comm, prev);
auto next_ch = std::make_shared<Channel>(comm, next);
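// Wait until both channels have flushed all queued operations.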
auto block = [&] {
for (auto ch : {prev_ch, next_ch}) {
auto rc = ch->Block();
if (!rc.OK()) {
return rc;
}
}
return Success();
};
rc = std::move(rc) << [&] {
return cpu_impl::RingAllgather(comm, s_buffer, HOST_NAME_MAX, 0, prev_ch, next_ch);
} << [&] { return block(); };
if (!rc.OK()) {
return Fail("Failed to get host names from peers.", std::move(rc));
}
std::vector<std::int32_t> peers_port(comm.World(), -1);
peers_port[comm.Rank()] = ninfo.port;
rc = std::move(rc) << [&] {
auto s_ports = common::Span{reinterpret_cast<std::int8_t*>(peers_port.data()),
peers_port.size() * sizeof(ninfo.port)};
return cpu_impl::RingAllgather(comm, s_ports, sizeof(ninfo.port), 0, prev_ch, next_ch);
} << [&] { return block(); };
if (!rc.OK()) {
return Fail("Failed to get the port from peers.", std::move(rc));
}
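// Slot r now holds the address of worker BootstrapNext(r), since each worker
// wrote its ring-next peer's host and port at its own rank.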
std::vector<proto::PeerInfo> peers(comm.World());
for (auto r = 0; r < comm.World(); ++r) {
auto nhost = s_buffer.subspan(HOST_NAME_MAX * r, HOST_NAME_MAX);
auto nport = peers_port[r];
auto nrank = BootstrapNext(r, comm.World());
peers[nrank] = {std::string{reinterpret_cast<char const*>(nhost.data())}, nport, nrank};
}
CHECK_EQ(peers[comm.Rank()].port, lport);
for (auto const& p : peers) {
CHECK_NE(p.port, -1);
}
std::vector<std::shared_ptr<TCPSocket>>& workers = *out_workers;
workers.resize(comm.World());
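// Build the full mesh: actively connect to every higher-ranked worker, then
// accept connections from every lower-ranked one. Between the two loops each
// pair of workers ends up with exactly one socket.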
for (std::int32_t r = (comm.Rank() + 1); r < comm.World(); ++r) {
auto const& peer = peers[r];
std::shared_ptr<TCPSocket> worker{TCPSocket::CreatePtr(comm.Domain())};
rc = std::move(rc)
<< [&] { return Connect(peer.host, peer.port, retry, timeout, worker.get()); }
<< [&] { return worker->RecvTimeout(timeout); };
if (!rc.OK()) {
return rc;
}
auto rank = comm.Rank();
std::size_t n_bytes{0};
auto rc = worker->SendAll(&rank, sizeof(comm.Rank()), &n_bytes);
if (!rc.OK()) {
return rc;
} else if (n_bytes != sizeof(comm.Rank())) {
return Fail("Failed to send rank.", std::move(rc));
}
workers[r] = std::move(worker);
}
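// Accept connections from the lower-ranked workers; each one sends its rank
// first so that the socket can be stored in the right slot.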
for (std::int32_t r = 0; r < comm.Rank(); ++r) {
auto peer = std::shared_ptr<TCPSocket>(TCPSocket::CreatePtr(comm.Domain()));
rc = std::move(rc) << [&] {
SockAddress addr;
return listener->Accept(peer.get(), &addr);
} << [&] {
return peer->RecvTimeout(timeout);
};
if (!rc.OK()) {
return rc;
}
std::int32_t rank{-1};
std::size_t n_bytes{0};
auto rc = peer->RecvAll(&rank, sizeof(rank), &n_bytes);
if (!rc.OK()) {
return rc;
} else if (n_bytes != sizeof(comm.Rank())) {
return Fail("Failed to recv rank.");
}
workers[rank] = std::move(peer);
}
for (std::int32_t r = 0; r < comm.World(); ++r) {
if (r == comm.Rank()) {
continue;
}
CHECK(workers[r]);
}
return Success();
}
namespace {
std::string InitLog(std::string task_id, std::int32_t rank) {
if (task_id.empty()) {
return "Rank " + std::to_string(rank);
}
return "Task " + task_id + " got rank " + std::to_string(rank);
}
} // namespace
RabitComm::RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
StringView nccl_path)
: HostComm{tracker_host, tracker_port, timeout, retry, std::move(task_id)},
nccl_path_{std::move(nccl_path)} {
if (this->TrackerInfo().host.empty()) {
// Not in a distributed environment.
LOG(CONSOLE) << InitLog(task_id_, rank_);
return;
}
loop_.reset(new Loop{std::chrono::seconds{timeout_}}); // NOLINT
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
if (!rc.OK()) {
this->ResetState();
SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc)));
}
}
#if !defined(XGBOOST_USE_NCCL)
Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
common::AssertGPUSupport();
common::AssertNCCLSupport();
return nullptr;
}
#endif // !defined(XGBOOST_USE_NCCL)
[[nodiscard]] Result RabitComm::Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
std::string task_id) {
TCPSocket tracker;
std::int32_t world{-1};
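// The world size is unknown until the tracker replies to the start command.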
auto rc = ConnectTrackerImpl(this->TrackerInfo(), timeout, retry, task_id, &tracker, this->Rank(),
world);
if (!rc.OK()) {
return Fail("Bootstrap failed.", std::move(rc));
}
this->domain_ = tracker.Domain();
// Create a listener for peer connections; its port is sent to the tracker as
// part of the start command.
TCPSocket listener = TCPSocket::Create(tracker.Domain());
std::int32_t lport{0};
rc = std::move(rc) << [&] {
return listener.BindHost(&lport);
} << [&] {
return listener.Listen();
};
if (!rc.OK()) {
return rc;
}
// Create a socket for listening to error notices from the tracker.
auto domain = tracker.Domain();
std::shared_ptr<TCPSocket> error_sock{TCPSocket::CreatePtr(domain)};
std::int32_t eport{0};
rc = std::move(rc) << [&] {
return error_sock->BindHost(&eport);
} << [&] {
return error_sock->Listen();
};
if (!rc.OK()) {
return rc;
}
error_port_ = eport;
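// Background worker: blocks in Accept until either the tracker signals that
// another worker hit an error, or Shutdown connects locally to wake it up.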
error_worker_ = std::thread{[error_sock = std::move(error_sock)] {
TCPSocket conn;
SockAddress addr;
auto rc = error_sock->Accept(&conn, &addr);
// On Linux, a shutdown causes an invalid argument error.
if (rc.Code() == std::errc::invalid_argument) {
return;
}
// On Windows, accept returns a closed socket after finalize.
if (conn.IsClosed()) {
return;
}
// The error signal is from the tracker, while shutdown signal is from the shutdown method
// of the RabitComm class (this).
bool is_error{false};
rc = proto::Error{}.RecvSignal(&conn, &is_error);
if (!rc.OK()) {
LOG(WARNING) << rc.Report();
return;
}
if (!is_error) {
return; // shutdown
}
LOG(WARNING) << "Another worker is running into error.";
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
// exit is nicer than abort as the former performs cleanups.
std::exit(-1);
#else
LOG(FATAL) << "abort";
#endif
}};
// The worker thread is detached here to avoid the need to handle it later during
// destruction. In C++, destroying a std::thread that is still joinable calls
// std::terminate.
error_worker_.detach();
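// Send the listener port and the error port to the tracker, then receive the
// world size in return.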
proto::Start start;
rc = std::move(rc) << [&] { return start.WorkerSend(lport, &tracker, eport); }
<< [&] { return start.WorkerRecv(&tracker, &world); };
if (!rc.OK()) {
return rc;
}
this->world_ = world;
// Get the next worker in the ring.
std::string snext;
rc = tracker.Recv(&snext);
if (!rc.OK()) {
return Fail("Failed to receive the rank for the next worker.", std::move(rc));
}
auto jnext = Json::Load(StringView{snext});
proto::PeerInfo ninfo{jnext};
// Get the rank of this worker: the tracker sent us our next neighbor in the
// ring, so our own rank is the predecessor of that neighbor's rank.
this->rank_ = BootstrapPrev(ninfo.rank, world);
this->tracker_.rank = rank_;
std::vector<std::shared_ptr<TCPSocket>> workers;
rc = ConnectWorkers(*this, &listener, lport, ninfo, timeout, retry, &workers);
if (!rc.OK()) {
return Fail("Failed to connect to other workers.", std::move(rc));
}
CHECK(this->channels_.empty());
for (auto& w : workers) {
if (w) {
rc = std::move(rc) << [&] {
return w->SetNoDelay();
} << [&] {
return w->NonBlocking(true);
} << [&] {
return w->SetKeepAlive();
};
}
if (!rc.OK()) {
return rc;
}
this->channels_.emplace_back(std::make_shared<Channel>(*this, w));
}
LOG(CONSOLE) << InitLog(task_id_, rank_);
return rc;
}
RabitComm::~RabitComm() noexcept(false) {
if (!this->IsDistributed()) {
return;
}
LOG(WARNING) << "The communicator is being destroyed without a call to shutdown first. This can "
"lead to undefined behaviour.";
auto rc = this->Shutdown();
if (!rc.OK()) {
LOG(WARNING) << rc.Report();
}
}
[[nodiscard]] Result RabitComm::Shutdown() {
if (!this->IsDistributed()) {
return Success();
}
// Tell the tracker that this worker is shutting down.
TCPSocket tracker;
// Tell the error handling thread that we are shutting down.
TCPSocket err_client;
return Success() << [&] {
return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World());
} << [&] {
return this->Block();
} << [&] {
return proto::ShutdownCMD{}.Send(&tracker);
} << [&] {
this->channels_.clear();
return Success();
} << [&] {
// Use tracker address to determine whether we want to use IPv6.
auto taddr = MakeSockAddress(xgboost::StringView{this->tracker_.host}, this->tracker_.port);
// Shutdown the error handling thread. We signal the thread through socket,
// alternatively, we can get the native handle and use pthread_cancel. But using a
// socket seems to be clearer as we know what's happening.
auto const& addr = taddr.IsV4() ? SockAddrV4::Loopback().Addr() : SockAddrV6::Loopback().Addr();
// We cap the timeout at 10 seconds and use a single retry here since we are
// just connecting to a local socket. For a normal OS, this should be enough
// time to schedule the connection.
auto rc = Connect(StringView{addr}, this->error_port_, 1,
std::min(std::chrono::seconds{10}, timeout_), &err_client);
this->ResetState();
if (!rc.OK()) {
return Fail("Failed to connect to the error socket.", std::move(rc));
}
return rc;
} << [&] {
// We put error thread shutdown at the end so that we have a better chance to finish
// the previous more important steps.
return proto::Error{}.SignalShutdown(&err_client);
};
}
[[nodiscard]] Result RabitComm::LogTracker(std::string msg) const {
if (!this->IsDistributed()) {
LOG(CONSOLE) << msg;
return Success();
}
TCPSocket out;
proto::Print print;
return Success() << [&] { return this->ConnectTracker(&out); }
<< [&] { return print.WorkerSend(&out, msg); };
}
[[nodiscard]] Result RabitComm::SignalError(Result const& res) {
TCPSocket tracker;
return Success() << [&] {
return this->ConnectTracker(&tracker);
} << [&] {
return proto::ErrorCMD{}.WorkerSend(&tracker, res);
};
}
} // namespace xgboost::collective