[coll] Implement a new tracker and a communicator. (#9650)
* [coll] Implement a new tracker and a communicator. The new tracker and the communicators exchange messages as JSON documents. In addition, the communicators are aware of each other.
This commit is contained in: parent 2e42f33fc1, commit 946ae1c440
@ -98,9 +98,13 @@ OBJECTS= \
    $(PKGROOT)/src/context.o \
    $(PKGROOT)/src/logging.o \
    $(PKGROOT)/src/global_config.o \
    $(PKGROOT)/src/collective/allgather.o \
    $(PKGROOT)/src/collective/comm.o \
    $(PKGROOT)/src/collective/tracker.o \
    $(PKGROOT)/src/collective/communicator.o \
    $(PKGROOT)/src/collective/in_memory_communicator.o \
    $(PKGROOT)/src/collective/in_memory_handler.o \
    $(PKGROOT)/src/collective/loop.o \
    $(PKGROOT)/src/collective/socket.o \
    $(PKGROOT)/src/common/charconv.o \
    $(PKGROOT)/src/common/column_matrix.o \
@ -157,4 +157,13 @@ struct Result {
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) {
  return Result{std::move(msg), std::move(errc), std::forward<Result>(prev)};
}

// We don't have a monad; a simple helper will do.
template <typename Fn>
Result operator<<(Result&& r, Fn&& fn) {
  if (!r.OK()) {
    return std::forward<Result>(r);
  }
  return fn();
}
}  // namespace xgboost::collective
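For illustration only (not part of this change), a minimal sketch of how the operator<< helper composes Result-returning steps. It assumes nothing beyond Success(), Fail(), and Result::OK() from result.h; the function name ChainedSetup is hypothetical:

#include "xgboost/collective/result.h"  // for Result, Success, Fail

namespace xgboost::collective {
// Each lambda runs only if the previous step returned an OK Result, so a
// failing step short-circuits the remainder of the chain.
inline Result ChainedSetup() {
  return Success()
      << [] { return Success(); }               // step 1 succeeds
      << [] { return Fail("step 2 failed."); }  // step 2 fails; the chain stops here
      << [] { return Success(); };              // step 3 is never invoked
}
}  // namespace xgboost::collective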
@ -380,11 +380,18 @@ class TCPSocket {
|
||||
}
|
||||
[[nodiscard]] bool NonBlocking() const { return non_blocking_; }
|
||||
[[nodiscard]] Result RecvTimeout(std::chrono::seconds timeout) {
|
||||
timeval tv;
|
||||
// https://stackoverflow.com/questions/2876024/linux-is-there-a-read-or-recv-from-socket-with-timeout
|
||||
#if defined(_WIN32)
|
||||
DWORD tv = timeout.count() * 1000;
|
||||
auto rc =
|
||||
setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char *>(&tv), sizeof(tv));
|
||||
#else
|
||||
struct timeval tv;
|
||||
tv.tv_sec = timeout.count();
|
||||
tv.tv_usec = 0;
|
||||
auto rc = setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char const *>(&tv),
|
||||
sizeof(tv));
|
||||
#endif
|
||||
if (rc != 0) {
|
||||
return system::FailWithCode("Failed to set timeout on recv.");
|
||||
}
|
||||
@ -425,7 +432,12 @@ class TCPSocket {
|
||||
*/
|
||||
TCPSocket Accept() {
|
||||
HandleT newfd = accept(Handle(), nullptr, nullptr);
|
||||
if (newfd == InvalidSocket()) {
|
||||
#if defined(_WIN32)
|
||||
auto interrupt = WSAEINTR;
|
||||
#else
|
||||
auto interrupt = EINTR;
|
||||
#endif
|
||||
if (newfd == InvalidSocket() && system::LastError() != interrupt) {
|
||||
system::ThrowAtError("accept");
|
||||
}
|
||||
TCPSocket newsock{newfd};
|
||||
@ -468,7 +480,7 @@ class TCPSocket {
|
||||
/**
|
||||
* \brief Bind socket to INADDR_ANY, return the port selected by the OS.
|
||||
*/
|
||||
in_port_t BindHost() {
|
||||
[[nodiscard]] in_port_t BindHost() {
|
||||
if (Domain() == SockDomain::kV6) {
|
||||
auto addr = SockAddrV6::InaddrAny();
|
||||
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
|
||||
@ -539,7 +551,7 @@ class TCPSocket {
|
||||
/**
|
||||
* \brief Send data; if no error occurred, all data will have been sent.
|
||||
*/
|
||||
auto SendAll(void const *buf, std::size_t len) {
|
||||
[[nodiscard]] auto SendAll(void const *buf, std::size_t len) {
|
||||
char const *_buf = reinterpret_cast<const char *>(buf);
|
||||
std::size_t ndone = 0;
|
||||
while (ndone < len) {
|
||||
@ -558,7 +570,7 @@ class TCPSocket {
|
||||
/**
|
||||
* \brief Receive data; if no error occurred, all data will have been received.
|
||||
*/
|
||||
auto RecvAll(void *buf, std::size_t len) {
|
||||
[[nodiscard]] auto RecvAll(void *buf, std::size_t len) {
|
||||
char *_buf = reinterpret_cast<char *>(buf);
|
||||
std::size_t ndone = 0;
|
||||
while (ndone < len) {
|
||||
@ -612,7 +624,15 @@ class TCPSocket {
|
||||
*/
|
||||
void Close() {
|
||||
if (InvalidSocket() != handle_) {
|
||||
#if defined(_WIN32)
|
||||
auto rc = system::CloseSocket(handle_);
|
||||
// it's possible that we close TCP sockets after finalizing WSA due to detached thread.
|
||||
if (rc != 0 && system::LastError() != WSANOTINITIALISED) {
|
||||
system::ThrowAtError("close", rc);
|
||||
}
|
||||
#else
|
||||
xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0);
|
||||
#endif
|
||||
handle_ = InvalidSocket();
|
||||
}
|
||||
}
|
||||
@ -634,6 +654,24 @@ class TCPSocket {
|
||||
socket.domain_ = domain;
|
||||
#endif // defined(__APPLE__)
|
||||
return socket;
|
||||
#endif // defined(xgboost_IS_MINGW)
|
||||
}
|
||||
|
||||
static TCPSocket *CreatePtr(SockDomain domain) {
|
||||
#if defined(xgboost_IS_MINGW)
|
||||
MingWError();
|
||||
return nullptr;
|
||||
#else
|
||||
auto fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
|
||||
if (fd == InvalidSocket()) {
|
||||
system::ThrowAtError("socket");
|
||||
}
|
||||
auto socket = new TCPSocket{fd};
|
||||
|
||||
#if defined(__APPLE__)
|
||||
socket->domain_ = domain;
|
||||
#endif // defined(__APPLE__)
|
||||
return socket;
|
||||
#endif // defined(xgboost_IS_MINGW)
|
||||
}
|
||||
};
|
||||
|
||||
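Since SendAll and RecvAll are now [[nodiscard]], callers are expected to check the returned byte count. A hedged sketch (not part of the diff; SendRank is a hypothetical helper) mirroring how comm.cc below uses the return value when exchanging ranks:

#include <cstdint>  // for int32_t

#include "xgboost/collective/result.h"  // for Result, Success, Fail
#include "xgboost/collective/socket.h"  // for TCPSocket

namespace xgboost::collective {
// Honour the [[nodiscard]] result of SendAll by checking the byte count.
inline Result SendRank(TCPSocket* sock, std::int32_t rank) {
  auto n_sent = sock->SendAll(&rank, sizeof(rank));
  if (n_sent != sizeof(rank)) {
    return Fail("Failed to send rank.");
  }
  return Success();
}
}  // namespace xgboost::collective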
src/collective/allgather.cc (new file, 42 lines)
@ -0,0 +1,42 @@
/**
 * Copyright 2023, XGBoost Contributors
 */
#include "allgather.h"

#include <algorithm>  // for min
#include <cstddef>    // for size_t
#include <cstdint>    // for int8_t
#include <memory>     // for shared_ptr

#include "comm.h"          // for Comm, Channel
#include "xgboost/span.h"  // for Span

namespace xgboost::collective::cpu_impl {
Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size_t segment_size,
                     std::int32_t worker_off, std::shared_ptr<Channel> prev_ch,
                     std::shared_ptr<Channel> next_ch) {
  auto world = comm.World();
  auto rank = comm.Rank();
  CHECK_LT(worker_off, world);

  for (std::int32_t r = 0; r < world; ++r) {
    auto send_rank = (rank + world - r + worker_off) % world;
    auto send_off = send_rank * segment_size;
    send_off = std::min(send_off, data.size_bytes());
    auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off));
    next_ch->SendAll(send_seg.data(), send_seg.size_bytes());

    auto recv_rank = (rank + world - r - 1 + worker_off) % world;
    auto recv_off = recv_rank * segment_size;
    recv_off = std::min(recv_off, data.size_bytes());
    auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off));
    prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
    auto rc = prev_ch->Block();
    if (!rc.OK()) {
      return rc;
    }
  }

  return Success();
}
}  // namespace xgboost::collective::cpu_impl
src/collective/allgather.h (new file, 23 lines)
@ -0,0 +1,23 @@
/**
 * Copyright 2023, XGBoost Contributors
 */
#pragma once
#include <cstddef>  // for size_t
#include <cstdint>  // for int32_t
#include <memory>   // for shared_ptr

#include "comm.h"          // for Comm, Channel
#include "xgboost/span.h"  // for Span

namespace xgboost::collective {
namespace cpu_impl {
/**
 * @param worker_off Segment offset. For example, if the rank 2 worker specifies
 *        worker_off = 1, then it owns the third segment.
 */
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
                                   std::size_t segment_size, std::int32_t worker_off,
                                   std::shared_ptr<Channel> prev_ch,
                                   std::shared_ptr<Channel> next_ch);
}  // namespace cpu_impl
}  // namespace xgboost::collective
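A usage sketch (not part of the commit) for cpu_impl::RingAllgather: each worker contributes one std::int32_t and, with worker_off = 0, worker r's segment lands at offset r. The comm and the two ring channels are assumed to be set up as in ConnectWorkers() in comm.cc; GatherRanks is a hypothetical helper name:

#include <cstdint>  // for int32_t, int8_t
#include <memory>   // for shared_ptr
#include <vector>   // for vector

#include "allgather.h"  // for RingAllgather
#include "comm.h"       // for Comm, Channel

namespace xgboost::collective {
// Gather one int32_t from every worker over the bootstrap ring.
inline Result GatherRanks(Comm const& comm, std::shared_ptr<Channel> prev_ch,
                          std::shared_ptr<Channel> next_ch, std::vector<std::int32_t>* out) {
  auto& values = *out;
  values.assign(comm.World(), 0);
  values[comm.Rank()] = comm.Rank();  // this worker's contribution
  auto erased = common::Span{reinterpret_cast<std::int8_t*>(values.data()),
                             values.size() * sizeof(std::int32_t)};
  // With worker_off = 0, worker r's segment ends up at offset r * sizeof(int32_t).
  return cpu_impl::RingAllgather(comm, erased, sizeof(std::int32_t), 0, prev_ch, next_ch)
         << [&] { return prev_ch->Block(); } << [&] { return next_ch->Block(); };
}
}  // namespace xgboost::collective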
src/collective/comm.cc (new file, 302 lines)
@ -0,0 +1,302 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include "comm.h"
|
||||
|
||||
#include <algorithm> // for copy
|
||||
#include <chrono> // for seconds
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string> // for string
|
||||
#include <utility> // for move, forward
|
||||
|
||||
#include "allgather.h"
|
||||
#include "protocol.h" // for kMagic
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||
#include "xgboost/json.h" // for Json, Object
|
||||
|
||||
namespace xgboost::collective {
|
||||
Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t retry, std::string task_id)
|
||||
: timeout_{timeout},
|
||||
retry_{retry},
|
||||
tracker_{host, port, -1},
|
||||
task_id_{std::move(task_id)},
|
||||
loop_{std::make_shared<Loop>(timeout)} {}
|
||||
|
||||
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string const& task_id, TCPSocket* out, std::int32_t rank,
|
||||
std::int32_t world) {
|
||||
// get information from tracker
|
||||
CHECK(!info.host.empty());
|
||||
auto rc = Connect(info.host, info.port, retry, timeout, out);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to connect to the tracker.", std::move(rc));
|
||||
}
|
||||
|
||||
TCPSocket& tracker = *out;
|
||||
return std::move(rc)
|
||||
<< [&] { return tracker.NonBlocking(false); }
|
||||
<< [&] { return tracker.RecvTimeout(timeout); }
|
||||
<< [&] { return proto::Magic{}.Verify(&tracker); }
|
||||
<< [&] { return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id); };
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Comm::ConnectTracker(TCPSocket* out) const {
|
||||
return ConnectTrackerImpl(this->TrackerInfo(), this->Timeout(), this->retry_, this->task_id_, out,
|
||||
this->Rank(), this->World());
|
||||
}
|
||||
|
||||
[[nodiscard]] Result ConnectWorkers(Comm const& comm, TCPSocket* listener, std::int32_t lport,
|
||||
proto::PeerInfo ninfo, std::chrono::seconds timeout,
|
||||
std::int32_t retry,
|
||||
std::vector<std::shared_ptr<TCPSocket>>* out_workers) {
|
||||
auto next = std::make_shared<TCPSocket>();
|
||||
auto prev = std::make_shared<TCPSocket>();
|
||||
|
||||
auto rc = Success() << [&] {
|
||||
auto rc = Connect(ninfo.host, ninfo.port, retry, timeout, next.get());
|
||||
if (!rc.OK()) {
|
||||
return Fail("Bootstrap failed to connect to ring next.", std::move(rc));
|
||||
}
|
||||
return rc;
|
||||
} << [&] {
|
||||
return next->NonBlocking(true);
|
||||
} << [&] {
|
||||
SockAddrV4 addr;
|
||||
return listener->Accept(prev.get(), &addr);
|
||||
} << [&] { return prev->NonBlocking(true); };
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
// exchange host name and port
|
||||
std::vector<std::int8_t> buffer(HOST_NAME_MAX * comm.World(), 0);
|
||||
auto s_buffer = common::Span{buffer.data(), buffer.size()};
|
||||
auto next_host = s_buffer.subspan(HOST_NAME_MAX * comm.Rank(), HOST_NAME_MAX);
|
||||
if (next_host.size() < ninfo.host.size()) {
|
||||
return Fail("Got an invalid host name.");
|
||||
}
|
||||
std::copy(ninfo.host.cbegin(), ninfo.host.cend(), next_host.begin());
|
||||
|
||||
auto prev_ch = std::make_shared<Channel>(comm, prev);
|
||||
auto next_ch = std::make_shared<Channel>(comm, next);
|
||||
|
||||
auto block = [&] {
|
||||
for (auto ch : {prev_ch, next_ch}) {
|
||||
auto rc = ch->Block();
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
}
|
||||
return Success();
|
||||
};
|
||||
|
||||
rc = std::move(rc) << [&] {
|
||||
return cpu_impl::RingAllgather(comm, s_buffer, HOST_NAME_MAX, 0, prev_ch, next_ch);
|
||||
} << [&] { return block(); };
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to get host names from peers.", std::move(rc));
|
||||
}
|
||||
|
||||
std::vector<std::int32_t> peers_port(comm.World(), -1);
|
||||
peers_port[comm.Rank()] = ninfo.port;
|
||||
rc = std::move(rc) << [&] {
|
||||
auto s_ports = common::Span{reinterpret_cast<std::int8_t*>(peers_port.data()),
|
||||
peers_port.size() * sizeof(ninfo.port)};
|
||||
return cpu_impl::RingAllgather(comm, s_ports, sizeof(ninfo.port), 0, prev_ch, next_ch);
|
||||
} << [&] { return block(); };
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to get the port from peers.", std::move(rc));
|
||||
}
|
||||
|
||||
std::vector<proto::PeerInfo> peers(comm.World());
|
||||
for (auto r = 0; r < comm.World(); ++r) {
|
||||
auto nhost = s_buffer.subspan(HOST_NAME_MAX * r, HOST_NAME_MAX);
|
||||
auto nport = peers_port[r];
|
||||
auto nrank = BootstrapNext(r, comm.World());
|
||||
|
||||
peers[nrank] = {std::string{reinterpret_cast<char const*>(nhost.data())}, nport, nrank};
|
||||
}
|
||||
CHECK_EQ(peers[comm.Rank()].port, lport);
|
||||
for (auto const& p : peers) {
|
||||
CHECK_NE(p.port, -1);
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<TCPSocket>>& workers = *out_workers;
|
||||
workers.resize(comm.World());
|
||||
|
||||
for (std::int32_t r = (comm.Rank() + 1); r < comm.World(); ++r) {
|
||||
auto const& peer = peers[r];
|
||||
std::shared_ptr<TCPSocket> worker{TCPSocket::CreatePtr(comm.Domain())};
|
||||
rc = std::move(rc)
|
||||
<< [&] { return Connect(peer.host, peer.port, retry, timeout, worker.get()); }
|
||||
<< [&] { return worker->RecvTimeout(timeout); };
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
auto rank = comm.Rank();
|
||||
auto n_bytes = worker->SendAll(&rank, sizeof(comm.Rank()));
|
||||
if (n_bytes != sizeof(comm.Rank())) {
|
||||
return Fail("Failed to send rank.");
|
||||
}
|
||||
workers[r] = std::move(worker);
|
||||
}
|
||||
|
||||
for (std::int32_t r = 0; r < comm.Rank(); ++r) {
|
||||
SockAddrV4 addr;
|
||||
auto peer = std::shared_ptr<TCPSocket>(TCPSocket::CreatePtr(comm.Domain()));
|
||||
rc = std::move(rc) << [&] { return listener->Accept(peer.get(), &addr); }
|
||||
<< [&] { return peer->RecvTimeout(timeout); };
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
std::int32_t rank{-1};
|
||||
auto n_bytes = peer->RecvAll(&rank, sizeof(rank));
|
||||
if (n_bytes != sizeof(comm.Rank())) {
|
||||
return Fail("Failed to recv rank.");
|
||||
}
|
||||
workers[rank] = std::move(peer);
|
||||
}
|
||||
|
||||
for (std::int32_t r = 0; r < comm.World(); ++r) {
|
||||
if (r == comm.Rank()) {
|
||||
continue;
|
||||
}
|
||||
CHECK(workers[r]);
|
||||
}
|
||||
|
||||
return Success();
|
||||
}
|
||||
|
||||
RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t retry, std::string task_id)
|
||||
: Comm{std::move(host), port, timeout, retry, std::move(task_id)} {
|
||||
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result RabitComm::Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string task_id) {
|
||||
TCPSocket tracker;
|
||||
std::int32_t world{-1};
|
||||
auto rc = ConnectTrackerImpl(this->TrackerInfo(), timeout, retry, task_id, &tracker, this->Rank(),
|
||||
world);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Bootstrap failed.", std::move(rc));
|
||||
}
|
||||
|
||||
this->domain_ = tracker.Domain();
|
||||
|
||||
// Start command
|
||||
TCPSocket listener = TCPSocket::Create(tracker.Domain());
|
||||
std::int32_t lport = listener.BindHost();
|
||||
listener.Listen();
|
||||
|
||||
// create worker for listening to error notice.
|
||||
auto domain = tracker.Domain();
|
||||
std::shared_ptr<TCPSocket> error_sock{TCPSocket::CreatePtr(domain)};
|
||||
auto eport = error_sock->BindHost();
|
||||
error_sock->Listen();
|
||||
error_worker_ = std::thread{[this, error_sock = std::move(error_sock)] {
|
||||
auto conn = error_sock->Accept();
|
||||
// On Windows accept returns an invalid socket after network is shutdown.
|
||||
if (conn.IsClosed()) {
|
||||
return;
|
||||
}
|
||||
LOG(WARNING) << "Another worker is running into error.";
|
||||
std::string scmd;
|
||||
conn.Recv(&scmd);
|
||||
auto jcmd = Json::Load(scmd);
|
||||
auto rc = this->Shutdown();
|
||||
if (!rc.OK()) {
|
||||
LOG(WARNING) << "Fail to shutdown worker:" << rc.Report();
|
||||
}
|
||||
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
|
||||
exit(-1);
|
||||
#else
|
||||
LOG(FATAL) << rc.Report();
|
||||
#endif
|
||||
}};
|
||||
error_worker_.detach();
|
||||
|
||||
proto::Start start;
|
||||
rc = std::move(rc) << [&] { return start.WorkerSend(lport, &tracker, eport); }
|
||||
<< [&] { return start.WorkerRecv(&tracker, &world); };
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
this->world_ = world;
|
||||
|
||||
// get ring neighbors
|
||||
std::string snext;
|
||||
tracker.Recv(&snext);
|
||||
auto jnext = Json::Load(StringView{snext});
|
||||
|
||||
proto::PeerInfo ninfo{jnext};
|
||||
|
||||
// get the rank of this worker
|
||||
this->rank_ = BootstrapPrev(ninfo.rank, world);
|
||||
this->tracker_.rank = rank_;
|
||||
|
||||
std::vector<std::shared_ptr<TCPSocket>> workers;
|
||||
rc = ConnectWorkers(*this, &listener, lport, ninfo, timeout, retry, &workers);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
CHECK(this->channels_.empty());
|
||||
for (auto& w : workers) {
|
||||
if (w) {
|
||||
w->SetNoDelay();
|
||||
rc = w->NonBlocking(true);
|
||||
}
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
this->channels_.emplace_back(std::make_shared<Channel>(*this, w));
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
RabitComm::~RabitComm() noexcept(false) {
|
||||
if (!IsDistributed()) {
|
||||
return;
|
||||
}
|
||||
auto rc = this->Shutdown();
|
||||
if (!rc.OK()) {
|
||||
LOG(WARNING) << rc.Report();
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]] Result RabitComm::Shutdown() {
|
||||
TCPSocket tracker;
|
||||
return Success() << [&] {
|
||||
return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World());
|
||||
} << [&] {
|
||||
return this->Block();
|
||||
} << [&] {
|
||||
Json jcmd{Object{}};
|
||||
jcmd["cmd"] = Integer{static_cast<std::int32_t>(proto::CMD::kShutdown)};
|
||||
auto scmd = Json::Dump(jcmd);
|
||||
auto n_bytes = tracker.Send(scmd);
|
||||
if (n_bytes != scmd.size()) {
|
||||
return Fail("Faled to send cmd.");
|
||||
}
|
||||
return Success();
|
||||
};
|
||||
}
|
||||
|
||||
[[nodiscard]] Result RabitComm::LogTracker(std::string msg) const {
|
||||
TCPSocket out;
|
||||
proto::Print print;
|
||||
return Success() << [&] { return this->ConnectTracker(&out); }
|
||||
<< [&] { return print.WorkerSend(&out, msg); };
|
||||
}
|
||||
|
||||
[[nodiscard]] Result RabitComm::SignalError(Result const& res) {
|
||||
TCPSocket out;
|
||||
return Success() << [&] { return this->ConnectTracker(&out); }
|
||||
<< [&] { return proto::ErrorCMD{}.WorkerSend(&out, res); };
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
src/collective/comm.h (new file, 160 lines)
@ -0,0 +1,160 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <chrono> // for seconds
|
||||
#include <condition_variable> // for condition_variable
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <mutex> // for mutex
|
||||
#include <queue> // for queue
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
#include <type_traits> // for remove_const_t
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/timer.h"
|
||||
#include "loop.h" // for Loop
|
||||
#include "protocol.h" // for PeerInfo
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
inline constexpr std::int32_t DefaultTimeoutSec() { return 300; } // 5min
|
||||
inline constexpr std::int32_t DefaultRetry() { return 3; }
|
||||
|
||||
// indexing into the ring
|
||||
inline std::int32_t BootstrapNext(std::int32_t r, std::int32_t world) {
|
||||
auto nrank = (r + world + 1) % world;
|
||||
return nrank;
|
||||
}
|
||||
|
||||
inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) {
|
||||
auto nrank = (r + world - 1) % world;
|
||||
return nrank;
|
||||
}
|
||||
|
||||
class Channel;
|
||||
|
||||
/**
|
||||
* @brief Base communicator storing info about the tracker and other communicators.
|
||||
*/
|
||||
class Comm {
|
||||
protected:
|
||||
std::int32_t world_{1};
|
||||
std::int32_t rank_{0};
|
||||
std::chrono::seconds timeout_{DefaultTimeoutSec()};
|
||||
std::int32_t retry_{DefaultRetry()};
|
||||
|
||||
proto::PeerInfo tracker_;
|
||||
SockDomain domain_{SockDomain::kV4};
|
||||
std::thread error_worker_;
|
||||
std::string task_id_;
|
||||
std::vector<std::shared_ptr<Channel>> channels_;
|
||||
std::shared_ptr<Loop> loop_{new Loop{std::chrono::seconds{
|
||||
DefaultTimeoutSec()}}}; // fixme: require federated comm to have a timeout
|
||||
|
||||
public:
|
||||
Comm() = default;
|
||||
Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string task_id);
|
||||
virtual ~Comm() noexcept(false) {} // NOLINT
|
||||
|
||||
Comm(Comm const& that) = delete;
|
||||
Comm& operator=(Comm const& that) = delete;
|
||||
Comm(Comm&& that) = delete;
|
||||
Comm& operator=(Comm&& that) = delete;
|
||||
|
||||
[[nodiscard]] auto TrackerInfo() const { return tracker_; }
|
||||
[[nodiscard]] Result ConnectTracker(TCPSocket* out) const;
|
||||
[[nodiscard]] auto Domain() const { return domain_; }
|
||||
[[nodiscard]] auto Timeout() const { return timeout_; }
|
||||
|
||||
[[nodiscard]] auto Rank() const { return rank_; }
|
||||
[[nodiscard]] auto World() const { return world_; }
|
||||
[[nodiscard]] bool IsDistributed() const { return World() > 1; }
|
||||
void Submit(Loop::Op op) const { loop_->Submit(op); }
|
||||
[[nodiscard]] Result Block() const { return loop_->Block(); }
|
||||
|
||||
[[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const {
|
||||
return channels_.at(rank);
|
||||
}
|
||||
[[nodiscard]] virtual bool IsFederated() const = 0;
|
||||
[[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
|
||||
|
||||
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
|
||||
};
|
||||
|
||||
class RabitComm : public Comm {
|
||||
[[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string task_id);
|
||||
[[nodiscard]] Result Shutdown();
|
||||
|
||||
public:
|
||||
// bootstrapping construction.
|
||||
RabitComm() = default;
|
||||
// ctor for testing where environment is known.
|
||||
RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t retry, std::string task_id);
|
||||
~RabitComm() noexcept(false) override;
|
||||
|
||||
[[nodiscard]] bool IsFederated() const override { return false; }
|
||||
[[nodiscard]] Result LogTracker(std::string msg) const override;
|
||||
|
||||
[[nodiscard]] Result SignalError(Result const&) override;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Communication channel between workers.
|
||||
*/
|
||||
class Channel {
|
||||
std::shared_ptr<TCPSocket> sock_{nullptr};
|
||||
Result rc_;
|
||||
Comm const& comm_;
|
||||
|
||||
public:
|
||||
explicit Channel(Comm const& comm, std::shared_ptr<TCPSocket> sock)
|
||||
: sock_{std::move(sock)}, comm_{comm} {}
|
||||
|
||||
void SendAll(std::int8_t const* ptr, std::size_t n) {
|
||||
Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast<std::int8_t*>(ptr), n, sock_.get(), 0};
|
||||
CHECK(sock_.get());
|
||||
comm_.Submit(std::move(op));
|
||||
}
|
||||
void SendAll(common::Span<std::int8_t const> data) {
|
||||
this->SendAll(data.data(), data.size_bytes());
|
||||
}
|
||||
|
||||
void RecvAll(std::int8_t* ptr, std::size_t n) {
|
||||
Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0};
|
||||
CHECK(sock_.get());
|
||||
comm_.Submit(std::move(op));
|
||||
}
|
||||
void RecvAll(common::Span<std::int8_t> data) { this->RecvAll(data.data(), data.size_bytes()); }
|
||||
|
||||
[[nodiscard]] auto Socket() const { return sock_; }
|
||||
[[nodiscard]] Result Block() { return comm_.Block(); }
|
||||
};
|
||||
|
||||
enum class Op { kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 };
|
||||
|
||||
template <typename T, typename U = std::conditional_t<std::is_const_v<T>,
|
||||
std::add_const_t<std::int8_t>, std::int8_t>>
|
||||
common::Span<U> EraseType(common::Span<T> data) {
|
||||
auto n_total_bytes = data.size_bytes();
|
||||
auto erased = common::Span{reinterpret_cast<std::add_pointer_t<U>>(data.data()), n_total_bytes};
|
||||
return erased;
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
common::Span<T> RestoreType(common::Span<U> data) {
|
||||
static_assert(std::is_same_v<std::remove_const_t<U>, std::int8_t>);
|
||||
auto n_total_bytes = data.size_bytes();
|
||||
auto restored = common::Span{reinterpret_cast<T*>(data.data()), n_total_bytes / sizeof(T)};
|
||||
return restored;
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
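A tiny check (not part of the commit) of the ring-index helpers declared above: BootstrapNext and BootstrapPrev are inverses of each other, which is why RabitComm::Bootstrap can recover its own rank from the ring neighbour the tracker sends (rank_ = BootstrapPrev(ninfo.rank, world)):

#include <cassert>  // for assert
#include <cstdint>  // for int32_t

#include "comm.h"  // for BootstrapNext, BootstrapPrev

int main() {
  using xgboost::collective::BootstrapNext;
  using xgboost::collective::BootstrapPrev;
  std::int32_t world = 4;
  for (std::int32_t r = 0; r < world; ++r) {
    // Going one step forward and one step back in the ring returns to r.
    assert(BootstrapPrev(BootstrapNext(r, world), world) == r);
  }
  // Rank 3's neighbours in a world of 4: next is 0, previous is 2.
  assert(BootstrapNext(3, world) == 0);
  assert(BootstrapPrev(3, world) == 2);
  return 0;
}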
src/collective/protocol.h (new file, 214 lines)
@ -0,0 +1,214 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstdint> // for int32_t
|
||||
#include <string> // for string
|
||||
#include <utility> // for move
|
||||
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective::proto {
|
||||
struct PeerInfo {
|
||||
std::string host;
|
||||
std::int32_t port{-1};
|
||||
std::int32_t rank{-1};
|
||||
|
||||
PeerInfo() = default;
|
||||
PeerInfo(std::string host, std::int32_t port, std::int32_t rank)
|
||||
: host{std::move(host)}, port{port}, rank{rank} {}
|
||||
|
||||
explicit PeerInfo(Json const& peer)
|
||||
: host{get<String>(peer["host"])},
|
||||
port{static_cast<std::int32_t>(get<Integer const>(peer["port"]))},
|
||||
rank{static_cast<std::int32_t>(get<Integer const>(peer["rank"]))} {}
|
||||
|
||||
[[nodiscard]] Json ToJson() const {
|
||||
Json info{Object{}};
|
||||
info["rank"] = rank;
|
||||
info["host"] = String{host};
|
||||
info["port"] = Integer{port};
|
||||
return info;
|
||||
}
|
||||
|
||||
[[nodiscard]] auto HostPort() const { return host + ":" + std::to_string(this->port); }
|
||||
};
|
||||
|
||||
struct Magic {
|
||||
static constexpr std::int32_t kMagic = 0xff99;
|
||||
|
||||
[[nodiscard]] Result Verify(xgboost::collective::TCPSocket* p_sock) {
|
||||
std::int32_t magic{kMagic};
|
||||
auto n_bytes = p_sock->SendAll(&magic, sizeof(magic));
|
||||
if (n_bytes != sizeof(magic)) {
|
||||
return Fail("Failed to verify.");
|
||||
}
|
||||
|
||||
magic = 0;
|
||||
n_bytes = p_sock->RecvAll(&magic, sizeof(magic));
|
||||
if (n_bytes != sizeof(magic)) {
|
||||
return Fail("Failed to verify.");
|
||||
}
|
||||
if (magic != kMagic) {
|
||||
return xgboost::collective::Fail("Invalid verification number.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
};
|
||||
|
||||
enum class CMD : std::int32_t {
|
||||
kInvalid = 0,
|
||||
kStart = 1,
|
||||
kShutdown = 2,
|
||||
kError = 3,
|
||||
kPrint = 4,
|
||||
};
|
||||
|
||||
struct Connect {
|
||||
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::int32_t world, std::int32_t rank,
|
||||
std::string task_id) const {
|
||||
Json jinit{Object{}};
|
||||
jinit["world_size"] = Integer{world};
|
||||
jinit["rank"] = Integer{rank};
|
||||
jinit["task_id"] = String{task_id};
|
||||
std::string msg;
|
||||
Json::Dump(jinit, &msg);
|
||||
auto n_bytes = tracker->Send(msg);
|
||||
if (n_bytes != msg.size()) {
|
||||
return Fail("Failed to send init command from worker.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
[[nodiscard]] Result TrackerRecv(TCPSocket* sock, std::int32_t* world, std::int32_t* rank,
|
||||
std::string* task_id) const {
|
||||
std::string init;
|
||||
sock->Recv(&init);
|
||||
auto jinit = Json::Load(StringView{init});
|
||||
*world = get<Integer const>(jinit["world_size"]);
|
||||
*rank = get<Integer const>(jinit["rank"]);
|
||||
*task_id = get<String const>(jinit["task_id"]);
|
||||
return Success();
|
||||
}
|
||||
};
|
||||
|
||||
class Start {
|
||||
private:
|
||||
[[nodiscard]] Result TrackerSend(std::int32_t world, TCPSocket* worker) const {
|
||||
Json jcmd{Object{}};
|
||||
jcmd["world_size"] = Integer{world};
|
||||
auto scmd = Json::Dump(jcmd);
|
||||
auto n_bytes = worker->Send(scmd);
|
||||
if (n_bytes != scmd.size()) {
|
||||
return Fail("Failed to send init command from tracker.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
public:
|
||||
[[nodiscard]] Result WorkerSend(std::int32_t lport, TCPSocket* tracker,
|
||||
std::int32_t eport) const {
|
||||
Json jcmd{Object{}};
|
||||
jcmd["cmd"] = Integer{static_cast<std::int32_t>(CMD::kStart)};
|
||||
jcmd["port"] = Integer{lport};
|
||||
jcmd["error_port"] = Integer{eport};
|
||||
auto scmd = Json::Dump(jcmd);
|
||||
auto n_bytes = tracker->Send(scmd);
|
||||
if (n_bytes != scmd.size()) {
|
||||
return Fail("Failed to send init command from worker.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
[[nodiscard]] Result WorkerRecv(TCPSocket* tracker, std::int32_t* p_world) const {
|
||||
std::string scmd;
|
||||
auto n_bytes = tracker->Recv(&scmd);
|
||||
if (n_bytes <= 0) {
|
||||
return Fail("Failed to recv init command from tracker.");
|
||||
}
|
||||
auto jcmd = Json::Load(scmd);
|
||||
auto world = get<Integer const>(jcmd["world_size"]);
|
||||
if (world <= 0) {
|
||||
return Fail("Invalid world size.");
|
||||
}
|
||||
*p_world = world;
|
||||
return Success();
|
||||
}
|
||||
[[nodiscard]] Result TrackerHandle(Json jcmd, std::int32_t* recv_world, std::int32_t world,
|
||||
std::int32_t* p_port, TCPSocket* p_sock,
|
||||
std::int32_t* eport) const {
|
||||
*p_port = get<Integer const>(jcmd["port"]);
|
||||
if (*p_port <= 0) {
|
||||
return Fail("Invalid port.");
|
||||
}
|
||||
if (*recv_world != -1) {
|
||||
return Fail("Invalid initialization sequence.");
|
||||
}
|
||||
*recv_world = world;
|
||||
*eport = get<Integer const>(jcmd["error_port"]);
|
||||
return TrackerSend(world, p_sock);
|
||||
}
|
||||
};
|
||||
|
||||
struct Print {
|
||||
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::string msg) const {
|
||||
Json jcmd{Object{}};
|
||||
jcmd["cmd"] = Integer{static_cast<std::int32_t>(CMD::kPrint)};
|
||||
jcmd["msg"] = String{std::move(msg)};
|
||||
auto scmd = Json::Dump(jcmd);
|
||||
auto n_bytes = tracker->Send(scmd);
|
||||
if (n_bytes != scmd.size()) {
|
||||
return Fail("Failed to send print command from worker.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
[[nodiscard]] Result TrackerHandle(Json jcmd, std::string* p_msg) const {
|
||||
if (!IsA<String>(jcmd["msg"])) {
|
||||
return Fail("Invalid print command.");
|
||||
}
|
||||
auto msg = get<String const>(jcmd["msg"]);
|
||||
*p_msg = msg;
|
||||
return Success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ErrorCMD {
|
||||
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, Result const& res) const {
|
||||
auto msg = res.Report();
|
||||
auto code = res.Code().value();
|
||||
Json jcmd{Object{}};
|
||||
jcmd["msg"] = String{std::move(msg)};
|
||||
jcmd["code"] = Integer{code};
|
||||
jcmd["cmd"] = Integer{static_cast<std::int32_t>(CMD::kError)};
|
||||
auto scmd = Json::Dump(jcmd);
|
||||
auto n_bytes = tracker->Send(scmd);
|
||||
if (n_bytes != scmd.size()) {
|
||||
return Fail("Failed to send error command from worker.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
[[nodiscard]] Result TrackerHandle(Json jcmd, std::string* p_msg, int* p_code) const {
|
||||
if (!IsA<String>(jcmd["msg"]) || !IsA<Integer>(jcmd["code"])) {
|
||||
return Fail("Invalid error command.");
|
||||
}
|
||||
auto msg = get<String const>(jcmd["msg"]);
|
||||
auto code = get<Integer const>(jcmd["code"]);
|
||||
*p_msg = msg;
|
||||
*p_code = code;
|
||||
return Success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ShutdownCMD {
|
||||
[[nodiscard]] Result Send(TCPSocket* peer) const {
|
||||
Json jcmd{Object{}};
|
||||
jcmd["cmd"] = Integer{static_cast<std::int32_t>(proto::CMD::kShutdown)};
|
||||
auto scmd = Json::Dump(jcmd);
|
||||
auto n_bytes = peer->Send(scmd);
|
||||
if (n_bytes != scmd.size()) {
|
||||
return Fail("Failed to send shutdown command from worker.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective::proto
|
||||
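Putting the protocol pieces together, a hedged sketch (not part of the commit; Handshake is a hypothetical helper) of the worker-side bootstrap handshake as performed by ConnectTrackerImpl() and RabitComm::Bootstrap() in comm.cc: connect to the tracker, verify the magic number, send the connect and start commands, then receive the world size:

#include <chrono>   // for seconds
#include <cstdint>  // for int32_t
#include <string>   // for string

#include "protocol.h"                   // for Magic, Connect, Start, PeerInfo
#include "xgboost/collective/socket.h"  // for TCPSocket, Connect

namespace xgboost::collective {
// lport/eport are the worker's listening and error ports; rank and world are
// not yet known during the bootstrap, so placeholders are sent.
inline Result Handshake(proto::PeerInfo const& info, std::chrono::seconds timeout,
                        std::string const& task_id, std::int32_t lport, std::int32_t eport,
                        std::int32_t* world) {
  TCPSocket tracker;
  proto::Start start;
  return Success()
      << [&] { return Connect(info.host, info.port, /*retry=*/1, timeout, &tracker); }
      << [&] { return tracker.NonBlocking(false); }
      << [&] { return tracker.RecvTimeout(timeout); }
      << [&] { return proto::Magic{}.Verify(&tracker); }  // exchange the magic number
      << [&] { return proto::Connect{}.WorkerSend(&tracker, /*world=*/-1, /*rank=*/0, task_id); }
      << [&] { return start.WorkerSend(lport, &tracker, eport); }  // report the ports
      << [&] { return start.WorkerRecv(&tracker, world); };        // receive the world size
}
}  // namespace xgboost::collective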
@ -15,12 +15,232 @@
|
||||
#include <ws2tcpip.h>
|
||||
#endif // defined(_WIN32)
|
||||
|
||||
#include <string> // for string
|
||||
#include <algorithm> // for sort
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int32_t
|
||||
#include <string> // for string
|
||||
#include <utility> // for move, forward
|
||||
|
||||
#include "../common/json_utils.h"
|
||||
#include "comm.h"
|
||||
#include "protocol.h" // for kMagic, PeerInfo
|
||||
#include "tracker.h"
|
||||
#include "xgboost/collective/result.h" // for Result, Fail, Success
|
||||
#include "xgboost/collective/socket.h" // for GetHostName, FailWithCode, MakeSockAddress, ...
|
||||
#include "xgboost/json.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
Tracker::Tracker(Json const& config)
|
||||
: n_workers_{static_cast<std::int32_t>(
|
||||
RequiredArg<Integer const>(config, "n_workers", __func__))},
|
||||
port_{static_cast<std::int32_t>(OptionalArg<Integer const>(config, "port", Integer::Int{0}))},
|
||||
timeout_{std::chrono::seconds{OptionalArg<Integer const>(
|
||||
config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
|
||||
|
||||
RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr)
|
||||
: sock_{std::move(sock)} {
|
||||
auto host = addr.Addr();
|
||||
|
||||
std::int32_t rank{0};
|
||||
rc_ = Success()
|
||||
<< [&] { return proto::Magic{}.Verify(&sock_); }
|
||||
<< [&] { return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_); };
|
||||
if (!rc_.OK()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::string cmd;
|
||||
sock_.Recv(&cmd);
|
||||
auto jcmd = Json::Load(StringView{cmd});
|
||||
cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
|
||||
std::int32_t port{0};
|
||||
if (cmd_ == proto::CMD::kStart) {
|
||||
proto::Start start;
|
||||
rc_ = start.TrackerHandle(jcmd, &world_, world, &port, &sock_, &eport_);
|
||||
} else if (cmd_ == proto::CMD::kPrint) {
|
||||
proto::Print print;
|
||||
rc_ = print.TrackerHandle(jcmd, &msg_);
|
||||
} else if (cmd_ == proto::CMD::kError) {
|
||||
proto::ErrorCMD error;
|
||||
rc_ = error.TrackerHandle(jcmd, &msg_, &code_);
|
||||
}
|
||||
if (!rc_.OK()) {
|
||||
return;
|
||||
}
|
||||
|
||||
info_ = proto::PeerInfo{host, port, rank};
|
||||
}
|
||||
|
||||
RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
|
||||
std::string self;
|
||||
auto rc = collective::GetHostAddress(&self);
|
||||
auto host = OptionalArg<String>(config, "host", self);
|
||||
|
||||
listener_ = TCPSocket::Create(SockDomain::kV4);
|
||||
rc = listener_.Bind(host, &this->port_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
listener_.Listen();
|
||||
}
|
||||
|
||||
Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
auto& workers = *p_workers;
|
||||
|
||||
std::sort(workers.begin(), workers.end(), WorkerCmp{});
|
||||
|
||||
std::vector<std::thread> bootstrap_threads;
|
||||
for (std::int32_t r = 0; r < n_workers_; ++r) {
|
||||
auto& worker = workers[r];
|
||||
auto next = BootstrapNext(r, n_workers_);
|
||||
auto const& next_w = workers[next];
|
||||
bootstrap_threads.emplace_back([next, &worker, &next_w] {
|
||||
auto jnext = proto::PeerInfo{next_w.Host(), next_w.Port(), next}.ToJson();
|
||||
std::string str;
|
||||
Json::Dump(jnext, &str);
|
||||
worker.Send(StringView{str});
|
||||
});
|
||||
}
|
||||
|
||||
for (auto& t : bootstrap_threads) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
for (auto const& w : workers) {
|
||||
worker_error_handles_.emplace_back(w.Host(), w.ErrorPort());
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
[[nodiscard]] std::future<Result> RabitTracker::Run() {
|
||||
// a state machine to keep track of consistency.
|
||||
struct State {
|
||||
std::int32_t const n_workers;
|
||||
|
||||
std::int32_t n_shutdown{0};
|
||||
bool during_restart{false};
|
||||
std::vector<WorkerProxy> pending;
|
||||
|
||||
explicit State(std::int32_t world) : n_workers{world} {}
|
||||
State(State const& that) = delete;
|
||||
State& operator=(State&& that) = delete;
|
||||
|
||||
void Start(WorkerProxy&& worker) {
|
||||
CHECK_LT(pending.size(), n_workers);
|
||||
CHECK_LE(n_shutdown, n_workers);
|
||||
|
||||
pending.emplace_back(std::forward<WorkerProxy>(worker));
|
||||
|
||||
CHECK_LE(pending.size(), n_workers);
|
||||
}
|
||||
void Shutdown() {
|
||||
CHECK_GE(n_shutdown, 0);
|
||||
CHECK_LT(n_shutdown, n_workers);
|
||||
|
||||
++n_shutdown;
|
||||
|
||||
CHECK_LE(n_shutdown, n_workers);
|
||||
}
|
||||
void Error() {
|
||||
CHECK_LE(pending.size(), n_workers);
|
||||
CHECK_LE(n_shutdown, n_workers);
|
||||
|
||||
during_restart = true;
|
||||
}
|
||||
[[nodiscard]] bool Ready() const {
|
||||
CHECK_LE(pending.size(), n_workers);
|
||||
return static_cast<std::int32_t>(pending.size()) == n_workers;
|
||||
}
|
||||
void Bootstrap() {
|
||||
CHECK_EQ(pending.size(), n_workers);
|
||||
CHECK_LE(n_shutdown, n_workers);
|
||||
|
||||
// A reset.
|
||||
n_shutdown = 0;
|
||||
during_restart = false;
|
||||
pending.clear();
|
||||
}
|
||||
[[nodiscard]] bool ShouldContinue() const {
|
||||
CHECK_LE(pending.size(), n_workers);
|
||||
CHECK_LE(n_shutdown, n_workers);
|
||||
// - Without error, we should shutdown after all workers are offline.
|
||||
// - With error, all workers are offline, and we have during_restart as true.
|
||||
return n_shutdown != n_workers || during_restart;
|
||||
}
|
||||
};
|
||||
|
||||
return std::async(std::launch::async, [this] {
|
||||
State state{this->n_workers_};
|
||||
|
||||
while (state.ShouldContinue()) {
|
||||
TCPSocket sock;
|
||||
SockAddrV4 addr;
|
||||
auto rc = listener_.Accept(&sock, &addr);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to accept connection.", std::move(rc));
|
||||
}
|
||||
|
||||
auto worker = WorkerProxy{n_workers_, std::move(sock), std::move(addr)};
|
||||
if (!worker.Status().OK()) {
|
||||
return Fail("Failed to initialize worker proxy.", std::move(worker.Status()));
|
||||
}
|
||||
switch (worker.Command()) {
|
||||
case proto::CMD::kStart: {
|
||||
state.Start(std::move(worker));
|
||||
if (state.Ready()) {
|
||||
rc = this->Bootstrap(&state.pending);
|
||||
state.Bootstrap();
|
||||
}
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
case proto::CMD::kShutdown: {
|
||||
state.Shutdown();
|
||||
continue;
|
||||
}
|
||||
case proto::CMD::kError: {
|
||||
if (state.during_restart) {
|
||||
continue;
|
||||
}
|
||||
state.Error();
|
||||
auto msg = worker.Msg();
|
||||
auto code = worker.Code();
|
||||
LOG(WARNING) << "Recieved error from [" << worker.Host() << ":" << worker.Rank()
|
||||
<< "]: " << msg << " code:" << code;
|
||||
auto host = worker.Host();
|
||||
// We signal all workers for the error, if they haven't aborted already.
|
||||
for (auto& w : worker_error_handles_) {
|
||||
if (w.first == host) {
|
||||
continue;
|
||||
}
|
||||
TCPSocket out;
|
||||
// retry is set to 1, just let the worker timeout or error. Otherwise the
|
||||
// tracker and the worker might be waiting for each other.
|
||||
auto rc = Connect(w.first, w.second, 1, timeout_, &out);
|
||||
// send signal to stop the worker.
|
||||
proto::ShutdownCMD shutdown;
|
||||
rc = shutdown.Send(&out);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to inform workers to stop.");
|
||||
}
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
case proto::CMD::kPrint: {
|
||||
LOG(CONSOLE) << worker.Msg();
|
||||
continue;
|
||||
}
|
||||
case proto::CMD::kInvalid:
|
||||
default: {
|
||||
return Fail("Invalid command received.");
|
||||
}
|
||||
}
|
||||
}
|
||||
return Success();
|
||||
});
|
||||
}
|
||||
|
||||
[[nodiscard]] Result GetHostAddress(std::string* out) {
|
||||
auto rc = GetHostName(out);
|
||||
if (!rc.OK()) {
|
||||
|
||||
@ -2,11 +2,137 @@
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <string> // for string
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int32_t
|
||||
#include <future> // for future
|
||||
#include <string> // for string
|
||||
#include <utility> // for pair
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "protocol.h"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
 * @brief Implementation of the RABIT tracker.
 *
 * * What is a tracker
 *
 * The collective implementation follows what RABIT did in the past. It requires a
 * tracker to coordinate the initialization and error recovery of workers. The
 * original implementation attempted to attain error resilience inside the collective
 * module, which turned out to be too challenging due to the large amount of external
 * state. The new implementation differs from RABIT in that neither state recovery nor
 * resilience is handled inside the collective module; it merely provides the
 * mechanism to signal errors to other workers through a centralized tracker.
 *
 * A tracker provides three major functionalities, namely:
 * - Initialization. Share the node addresses among all workers.
 * - Logging.
 * - Signal error. If an exception is thrown in one (or more) of the workers, it can
 *   signal an error to the tracker and the tracker will notify the other workers.
 */
|
||||
class Tracker {
|
||||
protected:
|
||||
std::int32_t n_workers_{0};
|
||||
std::int32_t port_{-1};
|
||||
std::chrono::seconds timeout_{0};
|
||||
|
||||
public:
|
||||
explicit Tracker(Json const& config);
|
||||
  Tracker(std::int32_t n_workers, std::int32_t port, std::chrono::seconds timeout)
      : n_workers_{n_workers}, port_{port}, timeout_{timeout} {}
|
||||
|
||||
virtual ~Tracker() noexcept(false){}; // NOLINT
|
||||
[[nodiscard]] virtual std::future<Result> Run() = 0;
|
||||
[[nodiscard]] virtual Json WorkerArgs() const = 0;
|
||||
[[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; }
|
||||
};
|
||||
|
||||
class RabitTracker : public Tracker {
|
||||
// a wrapper for connected worker socket.
|
||||
class WorkerProxy {
|
||||
TCPSocket sock_;
|
||||
proto::PeerInfo info_;
|
||||
std::int32_t eport_{0};
|
||||
std::int32_t world_{-1};
|
||||
std::string task_id_;
|
||||
|
||||
proto::CMD cmd_{proto::CMD::kInvalid};
|
||||
std::string msg_;
|
||||
std::int32_t code_{0};
|
||||
Result rc_;
|
||||
|
||||
public:
|
||||
explicit WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr);
|
||||
WorkerProxy(WorkerProxy const& that) = delete;
|
||||
WorkerProxy(WorkerProxy&& that) = default;
|
||||
WorkerProxy& operator=(WorkerProxy const&) = delete;
|
||||
WorkerProxy& operator=(WorkerProxy&&) = default;
|
||||
|
||||
[[nodiscard]] auto Host() const { return info_.host; }
|
||||
[[nodiscard]] auto TaskID() const { return task_id_; }
|
||||
[[nodiscard]] auto Port() const { return info_.port; }
|
||||
[[nodiscard]] auto Rank() const { return info_.rank; }
|
||||
[[nodiscard]] auto ErrorPort() const { return eport_; }
|
||||
[[nodiscard]] auto Command() const { return cmd_; }
|
||||
[[nodiscard]] auto Msg() const { return msg_; }
|
||||
[[nodiscard]] auto Code() const { return code_; }
|
||||
|
||||
[[nodiscard]] Result const& Status() const { return rc_; }
|
||||
[[nodiscard]] Result& Status() { return rc_; }
|
||||
|
||||
void Send(StringView value) { this->sock_.Send(value); }
|
||||
};
|
||||
// Provide an ordering for workers; this helps us get a deterministic topology.
|
||||
struct WorkerCmp {
|
||||
[[nodiscard]] bool operator()(WorkerProxy const& lhs, WorkerProxy const& rhs) {
|
||||
auto const& lh = lhs.Host();
|
||||
auto const& rh = rhs.Host();
|
||||
|
||||
if (lh != rh) {
|
||||
return lh < rh;
|
||||
}
|
||||
return lhs.TaskID() < rhs.TaskID();
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
std::string host_;
|
||||
// record for how to reach out to workers if error happens.
|
||||
std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
|
||||
// listening socket for incoming workers.
|
||||
TCPSocket listener_;
|
||||
|
||||
Result Bootstrap(std::vector<WorkerProxy>* p_workers);
|
||||
|
||||
public:
|
||||
  explicit RabitTracker(StringView host, std::int32_t n_workers, std::int32_t port,
                        std::chrono::seconds timeout)
      : Tracker{n_workers, port, timeout}, host_{host.c_str(), host.size()} {
|
||||
listener_ = TCPSocket::Create(SockDomain::kV4);
|
||||
auto rc = listener_.Bind(host, &this->port_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
listener_.Listen();
|
||||
}
|
||||
|
||||
explicit RabitTracker(Json const& config);
|
||||
~RabitTracker() noexcept(false) override = default;
|
||||
|
||||
std::future<Result> Run() override;
|
||||
|
||||
[[nodiscard]] std::int32_t Port() const { return port_; }
|
||||
[[nodiscard]] Json WorkerArgs() const override {
|
||||
Json args{Object{}};
|
||||
args["DMLC_TRACKER_URI"] = String{host_};
|
||||
args["DMLC_TRACKER_PORT"] = this->Port();
|
||||
return args;
|
||||
}
|
||||
};
|
||||
|
||||
// Probe the public IP address of the host; we need a better method.
|
||||
//
|
||||
// This is directly translated from the previous Python implementation, we should find a
|
||||
|
||||
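The unit tests below exercise exactly this flow. As a compact illustration (not part of the commit; RunLocalCluster is a hypothetical helper), a local cluster can be driven in-process as follows: start the tracker, let each worker's RabitComm bootstrap against it, then wait on the tracker future:

#include <chrono>   // for seconds
#include <cstdint>  // for int32_t
#include <string>   // for string, to_string
#include <thread>   // for thread
#include <vector>   // for vector

#include "comm.h"     // for RabitComm, DefaultTimeoutSec, DefaultRetry
#include "tracker.h"  // for RabitTracker, GetHostAddress

namespace xgboost::collective {
inline void RunLocalCluster(std::int32_t n_workers) {
  std::string host;
  auto rc = GetHostAddress(&host);
  CHECK(rc.OK()) << rc.Report();

  std::chrono::seconds timeout{DefaultTimeoutSec()};
  RabitTracker tracker{host, n_workers, /*port=*/0, timeout};
  auto fut = tracker.Run();

  std::vector<std::thread> workers;
  std::int32_t port = tracker.Port();
  for (std::int32_t i = 0; i < n_workers; ++i) {
    workers.emplace_back([=] {
      // The RabitComm constructor performs the bootstrap handshake.
      RabitComm comm{host, port, timeout, DefaultRetry(), "t:" + std::to_string(i)};
    });
  }
  for (auto& w : workers) {
    w.join();
  }
  CHECK(fut.get().OK());
}
}  // namespace xgboost::collective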
tests/cpp/collective/test_comm.cc (new file, 47 lines)
@ -0,0 +1,47 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "../../../src/collective/comm.h"
|
||||
#include "test_worker.h"
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class CommTest : public TrackerTest {};
|
||||
} // namespace
|
||||
|
||||
TEST_F(CommTest, Channel) {
|
||||
auto n_workers = 4;
|
||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] {
|
||||
WorkerForTest worker{host, port, timeout, n_workers, i};
|
||||
if (i % 2 == 0) {
|
||||
auto p_chan = worker.Comm().Chan(i + 1);
|
||||
p_chan->SendAll(
|
||||
EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
|
||||
auto rc = p_chan->Block();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
} else {
|
||||
auto p_chan = worker.Comm().Chan(i - 1);
|
||||
std::int32_t r{-1};
|
||||
p_chan->RecvAll(EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
|
||||
auto rc = p_chan->Block();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_EQ(r, i - 1);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for (auto &w : workers) {
|
||||
w.join();
|
||||
}
|
||||
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@ -7,7 +7,7 @@
|
||||
#include <cerrno> // EADDRNOTAVAIL
|
||||
#include <system_error> // std::error_code, std::system_category
|
||||
|
||||
#include "net_test.h" // for SocketTest
|
||||
#include "test_worker.h" // for SocketTest
|
||||
|
||||
namespace xgboost::collective {
|
||||
TEST_F(SocketTest, Basic) {
|
||||
|
||||
@ -1,18 +1,67 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include "../../../src/collective/tracker.h" // for GetHostAddress
|
||||
#include "net_test.h" // for SocketTest
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int32_t
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/collective/comm.h"
|
||||
#include "test_worker.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class TrackerTest : public SocketTest {};
|
||||
class PrintWorker : public WorkerForTest {
|
||||
public:
|
||||
using WorkerForTest::WorkerForTest;
|
||||
|
||||
void Print() {
|
||||
auto rc = comm_.LogTracker("ack:" + std::to_string(this->comm_.Rank()));
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST_F(TrackerTest, GetHostAddress) {
|
||||
std::string host;
|
||||
auto rc = GetHostAddress(&host);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
ASSERT_TRUE(host.find("127.") == std::string::npos);
|
||||
TEST_F(TrackerTest, Bootstrap) {
|
||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] { WorkerForTest worker{host, port, timeout, n_workers, i}; });
|
||||
}
|
||||
for (auto &w : workers) {
|
||||
w.join();
|
||||
}
|
||||
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
|
||||
TEST_F(TrackerTest, Print) {
|
||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] {
|
||||
PrintWorker worker{host, port, timeout, n_workers, i};
|
||||
worker.Print();
|
||||
});
|
||||
}
|
||||
|
||||
for (auto &w : workers) {
|
||||
w.join();
|
||||
}
|
||||
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
|
||||
TEST_F(TrackerTest, GetHostAddress) { ASSERT_TRUE(host.find("127.") == std::string::npos); }
|
||||
} // namespace xgboost::collective
|
||||
|
||||
tests/cpp/collective/test_worker.h (new file, 91 lines)
@ -0,0 +1,91 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int32_t
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/collective/comm.h"
|
||||
#include "../../../src/collective/tracker.h" // for GetHostAddress
|
||||
#include "../helpers.h" // for FileExists
|
||||
|
||||
namespace xgboost::collective {
|
||||
class WorkerForTest {
|
||||
std::string tracker_host_;
|
||||
std::int32_t tracker_port_;
|
||||
std::int32_t world_size_;
|
||||
|
||||
protected:
|
||||
std::int32_t retry_{1};
|
||||
std::string task_id_;
|
||||
RabitComm comm_;
|
||||
|
||||
public:
|
||||
WorkerForTest(std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t world, std::int32_t rank)
|
||||
: tracker_host_{std::move(host)},
|
||||
tracker_port_{port},
|
||||
world_size_{world},
|
||||
task_id_{"t:" + std::to_string(rank)},
|
||||
comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_} {
|
||||
CHECK_EQ(world_size_, comm_.World());
|
||||
}
|
||||
virtual ~WorkerForTest() = default;
|
||||
auto& Comm() { return comm_; }
|
||||
|
||||
void LimitSockBuf(std::int32_t n_bytes) {
|
||||
for (std::int32_t i = 0; i < comm_.World(); ++i) {
|
||||
if (i != comm_.Rank()) {
|
||||
ASSERT_TRUE(comm_.Chan(i)->Socket()->NonBlocking());
|
||||
ASSERT_TRUE(comm_.Chan(i)->Socket()->SetBufSize(n_bytes).OK());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class SocketTest : public ::testing::Test {
|
||||
protected:
|
||||
std::string skip_msg_{"Skipping IPv6 test"};
|
||||
|
||||
bool SkipTest() {
|
||||
std::string path{"/sys/module/ipv6/parameters/disable"};
|
||||
if (FileExists(path)) {
|
||||
std::ifstream fin(path);
|
||||
if (!fin) {
|
||||
return true;
|
||||
}
|
||||
std::string s_value;
|
||||
fin >> s_value;
|
||||
auto value = std::stoi(s_value);
|
||||
if (value != 0) {
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetUp() override { system::SocketStartup(); }
|
||||
void TearDown() override { system::SocketFinalize(); }
|
||||
};
|
||||
|
||||
class TrackerTest : public SocketTest {
|
||||
public:
|
||||
std::int32_t n_workers{2};
|
||||
std::chrono::seconds timeout{1};
|
||||
std::string host;
|
||||
|
||||
void SetUp() override {
|
||||
SocketTest::SetUp();
|
||||
auto rc = GetHostAddress(&host);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||