[coll] Implement a new tracker and a communicator. (#9650)

* [coll] Implement a new tracker and a communicator.

The new tracker and communicators communicate by exchanging JSON documents. In
addition, communicators are now aware of each other.
This commit is contained in:
Jiaming Yuan 2023-10-12 12:49:16 +08:00 committed by GitHub
parent 2e42f33fc1
commit 946ae1c440
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 1345 additions and 16 deletions

View File

@ -98,9 +98,13 @@ OBJECTS= \
$(PKGROOT)/src/context.o \
$(PKGROOT)/src/logging.o \
$(PKGROOT)/src/global_config.o \
$(PKGROOT)/src/collective/allgather.o \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/communicator.o \
$(PKGROOT)/src/collective/in_memory_communicator.o \
$(PKGROOT)/src/collective/in_memory_handler.o \
$(PKGROOT)/src/collective/loop.o \
$(PKGROOT)/src/collective/socket.o \
$(PKGROOT)/src/common/charconv.o \
$(PKGROOT)/src/common/column_matrix.o \

View File

@ -98,9 +98,13 @@ OBJECTS= \
$(PKGROOT)/src/context.o \
$(PKGROOT)/src/logging.o \
$(PKGROOT)/src/global_config.o \
$(PKGROOT)/src/collective/allgather.o \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/communicator.o \
$(PKGROOT)/src/collective/in_memory_communicator.o \
$(PKGROOT)/src/collective/in_memory_handler.o \
$(PKGROOT)/src/collective/loop.o \
$(PKGROOT)/src/collective/socket.o \
$(PKGROOT)/src/common/charconv.o \
$(PKGROOT)/src/common/column_matrix.o \

View File

@ -157,4 +157,13 @@ struct Result {
// Build a Result from a message and an error code, attaching `prev` as the preceding result.
// `prev` is taken by rvalue reference and forwarded into the new Result.
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) {
return Result{std::move(msg), std::move(errc), std::forward<Result>(prev)};
}
/**
 * @brief Chain a step onto a result.
 *
 * We don't have monads, so a simple helper will do: when `r` is OK the callable `fn`
 * is invoked and its result is returned; otherwise `fn` is skipped and `r` is passed
 * through unchanged.
 */
template <typename Fn>
Result operator<<(Result&& r, Fn&& fn) {
  if (r.OK()) {
    return fn();
  }
  return std::forward<Result>(r);
}
} // namespace xgboost::collective

View File

@ -380,11 +380,18 @@ class TCPSocket {
}
[[nodiscard]] bool NonBlocking() const { return non_blocking_; }
[[nodiscard]] Result RecvTimeout(std::chrono::seconds timeout) {
timeval tv;
// https://stackoverflow.com/questions/2876024/linux-is-there-a-read-or-recv-from-socket-with-timeout
#if defined(_WIN32)
DWORD tv = timeout.count() * 1000;
auto rc =
setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char *>(&tv), sizeof(tv));
#else
struct timeval tv;
tv.tv_sec = timeout.count();
tv.tv_usec = 0;
auto rc = setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char const *>(&tv),
sizeof(tv));
#endif
if (rc != 0) {
return system::FailWithCode("Failed to set timeout on recv.");
}
@ -425,7 +432,12 @@ class TCPSocket {
*/
TCPSocket Accept() {
HandleT newfd = accept(Handle(), nullptr, nullptr);
if (newfd == InvalidSocket()) {
#if defined(_WIN32)
auto interrupt = WSAEINTR;
#else
auto interrupt = EINTR;
#endif
if (newfd == InvalidSocket() && system::LastError() != interrupt) {
system::ThrowAtError("accept");
}
TCPSocket newsock{newfd};
@ -468,7 +480,7 @@ class TCPSocket {
/**
* \brief Bind socket to INADDR_ANY, return the port selected by the OS.
*/
in_port_t BindHost() {
[[nodiscard]] in_port_t BindHost() {
if (Domain() == SockDomain::kV6) {
auto addr = SockAddrV6::InaddrAny();
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
@ -539,7 +551,7 @@ class TCPSocket {
/**
* \brief Send data, without error then all data should be sent.
*/
auto SendAll(void const *buf, std::size_t len) {
[[nodiscard]] auto SendAll(void const *buf, std::size_t len) {
char const *_buf = reinterpret_cast<const char *>(buf);
std::size_t ndone = 0;
while (ndone < len) {
@ -558,7 +570,7 @@ class TCPSocket {
/**
* \brief Receive data, without error then all data should be received.
*/
auto RecvAll(void *buf, std::size_t len) {
[[nodiscard]] auto RecvAll(void *buf, std::size_t len) {
char *_buf = reinterpret_cast<char *>(buf);
std::size_t ndone = 0;
while (ndone < len) {
@ -612,7 +624,15 @@ class TCPSocket {
*/
void Close() {
// Only act when the handle is still valid.
if (InvalidSocket() != handle_) {
#if defined(_WIN32)
auto rc = system::CloseSocket(handle_);
// it's possible that we close TCP sockets after finalizing WSA due to detached thread.
// A WSANOTINITIALISED error is therefore tolerated; any other failure is raised.
if (rc != 0 && system::LastError() != WSANOTINITIALISED) {
system::ThrowAtError("close", rc);
}
#else
// The check macro is given 0 as the expected return value of the close call.
xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0);
#endif
// Reset the handle so the guard above is false on a subsequent call.
handle_ = InvalidSocket();
}
}
@ -634,6 +654,24 @@ class TCPSocket {
socket.domain_ = domain;
#endif // defined(__APPLE__)
return socket;
#endif // defined(xgboost_IS_MINGW)
}
// Create a stream socket for `domain` and return it as a heap-allocated TCPSocket.
// The object is allocated with `new`, so ownership passes to the caller.
// Under MinGW this reports an error via MingWError() and returns nullptr.
static TCPSocket *CreatePtr(SockDomain domain) {
#if defined(xgboost_IS_MINGW)
MingWError();
return nullptr;
#else
auto fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
// Raise immediately if the OS could not create the socket.
if (fd == InvalidSocket()) {
system::ThrowAtError("socket");
}
auto socket = new TCPSocket{fd};
#if defined(__APPLE__)
// The domain is stored explicitly on Apple platforms.
socket->domain_ = domain;
#endif // defined(__APPLE__)
return socket;
#endif // defined(xgboost_IS_MINGW)
}
};

View File

@ -0,0 +1,42 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "allgather.h"
#include <algorithm> // for min
#include <cstddef> // for size_t
#include <cstdint> // for int8_t
#include <memory> // for shared_ptr
#include "comm.h" // for Comm, Channel
#include "xgboost/span.h" // for Span
namespace xgboost::collective::cpu_impl {
/**
 * @brief Allgather over a ring: in each of `world` steps, queue one segment for the next
 *        worker, queue the receive of one segment from the previous worker, then block
 *        until the queued operations complete.
 *
 * Segments are `segment_size` bytes wide and are clamped to the end of `data`, so the
 * trailing segment may be shorter (or empty).
 */
Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size_t segment_size,
                     std::int32_t worker_off, std::shared_ptr<Channel> prev_ch,
                     std::shared_ptr<Channel> next_ch) {
  auto const world = comm.World();
  auto const rank = comm.Rank();
  CHECK_LT(worker_off, world);

  auto const n_bytes = data.size_bytes();
  // Slice out the segment that belongs to `owner`, clamped to the buffer.
  auto segment_of = [&](std::int32_t owner) {
    auto const begin = std::min(owner * segment_size, n_bytes);
    return data.subspan(begin, std::min(segment_size, n_bytes - begin));
  };

  for (std::int32_t step = 0; step < world; ++step) {
    auto const outgoing = segment_of((rank + world - step + worker_off) % world);
    next_ch->SendAll(outgoing.data(), outgoing.size_bytes());

    auto const incoming = segment_of((rank + world - step - 1 + worker_off) % world);
    prev_ch->RecvAll(incoming.data(), incoming.size_bytes());

    // Wait for this step to finish before moving on to the next one.
    if (auto rc = prev_ch->Block(); !rc.OK()) {
      return rc;
    }
  }

  return Success();
}
} // namespace xgboost::collective::cpu_impl

View File

@ -0,0 +1,23 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <memory> // for shared_ptr
#include "comm.h" // for Comm, Channel
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
namespace cpu_impl {
/**
* @brief Allgather implemented over a ring of workers; `data` is treated as consecutive
*        segments of `segment_size` bytes.
*
* @param comm         Communicator providing the rank and the world size.
* @param data         Byte buffer holding all segments.
* @param segment_size Size in bytes of each segment.
* @param worker_off Segment offset. For example, if the rank 2 worker specifies worker_off
* = 1, then it owns the third segment.
*        NOTE(review): the first segment a worker sends has index (rank + worker_off) % world;
*        confirm whether "third" above is meant zero-based.
* @param prev_ch      Channel to the previous worker in the ring (data is received from it).
* @param next_ch      Channel to the next worker in the ring (data is sent to it).
*/
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
std::size_t segment_size, std::int32_t worker_off,
std::shared_ptr<Channel> prev_ch,
std::shared_ptr<Channel> next_ch);
} // namespace cpu_impl
} // namespace xgboost::collective

302
src/collective/comm.cc Normal file
View File

@ -0,0 +1,302 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "comm.h"
#include <algorithm> // for copy
#include <chrono> // for seconds
#include <memory> // for shared_ptr
#include <string> // for string
#include <utility> // for move, forward
#include "allgather.h"
#include "protocol.h" // for kMagic
#include "xgboost/collective/socket.h" // for TCPSocket
#include "xgboost/json.h" // for Json, Object
namespace xgboost::collective {
// Store the tracker address and connection parameters. The tracker's rank field is
// initialised to -1, and a Loop is created with the same timeout.
Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t retry, std::string task_id)
: timeout_{timeout},
retry_{retry},
tracker_{host, port, -1},
task_id_{std::move(task_id)},
loop_{std::make_shared<Loop>(timeout)} {}
/**
 * @brief Connect to the tracker described by `info` and perform the initial handshake.
 *
 * After connecting, the socket is switched to blocking mode, a receive timeout is set,
 * the magic number is verified and the connect message (world, rank, task id) is sent.
 * The connected socket is written to `out`.
 */
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
                          std::string const& task_id, TCPSocket* out, std::int32_t rank,
                          std::int32_t world) {
  // get information from tracker
  CHECK(!info.host.empty());

  auto connected = Connect(info.host, info.port, retry, timeout, out);
  if (!connected.OK()) {
    return Fail("Failed to connect to the tracker.", std::move(connected));
  }

  // Each step only runs when every previous step succeeded.
  auto set_blocking = [&] { return out->NonBlocking(false); };
  auto set_timeout = [&] { return out->RecvTimeout(timeout); };
  auto verify = [&] { return proto::Magic{}.Verify(out); };
  auto introduce = [&] { return proto::Connect{}.WorkerSend(out, world, rank, task_id); };

  return std::move(connected) << set_blocking << set_timeout << verify << introduce;
}
// Connect to the tracker using this communicator's stored tracker info, timeout, retry
// count, task id, rank and world size. The connected socket is written to `out`.
[[nodiscard]] Result Comm::ConnectTracker(TCPSocket* out) const {
return ConnectTrackerImpl(this->TrackerInfo(), this->Timeout(), this->retry_, this->task_id_, out,
this->Rank(), this->World());
}
/**
 * @brief Establish pairwise connections between this worker and all other workers.
 *
 * The function first connects to the ring neighbours (`ninfo` is the next worker, the
 * previous worker is accepted on `listener`), then uses ring allgather to exchange the
 * host names and listening ports of all workers. Finally, this worker connects to every
 * peer with a higher rank and accepts connections from every peer with a lower rank.
 *
 * @param comm        Communicator providing rank, world size and socket domain.
 * @param listener    Listening socket of this worker.
 * @param lport       Port `listener` is bound to; checked against the gathered info.
 * @param ninfo       Address of the next worker in the ring.
 * @param timeout     Timeout for connecting and receiving.
 * @param retry       Number of connection retries.
 * @param out_workers Output, indexed by rank. The entry for this worker stays empty.
 */
[[nodiscard]] Result ConnectWorkers(Comm const& comm, TCPSocket* listener, std::int32_t lport,
                                    proto::PeerInfo ninfo, std::chrono::seconds timeout,
                                    std::int32_t retry,
                                    std::vector<std::shared_ptr<TCPSocket>>* out_workers) {
  auto next = std::make_shared<TCPSocket>();
  auto prev = std::make_shared<TCPSocket>();

  auto rc = Success() << [&] {
    auto rc = Connect(ninfo.host, ninfo.port, retry, timeout, next.get());
    if (!rc.OK()) {
      return Fail("Bootstrap failed to connect to ring next.", std::move(rc));
    }
    return rc;
  } << [&] {
    return next->NonBlocking(true);
  } << [&] {
    SockAddrV4 addr;
    return listener->Accept(prev.get(), &addr);
  } << [&] { return prev->NonBlocking(true); };

  if (!rc.OK()) {
    return rc;
  }

  // exchange host name and port
  std::vector<std::int8_t> buffer(HOST_NAME_MAX * comm.World(), 0);
  auto s_buffer = common::Span{buffer.data(), buffer.size()};
  auto next_host = s_buffer.subspan(HOST_NAME_MAX * comm.Rank(), HOST_NAME_MAX);
  if (next_host.size() < ninfo.host.size()) {
    return Fail("Got an invalid host name.");
  }
  std::copy(ninfo.host.cbegin(), ninfo.host.cend(), next_host.begin());

  auto prev_ch = std::make_shared<Channel>(comm, prev);
  auto next_ch = std::make_shared<Channel>(comm, next);

  auto block = [&] {
    for (auto const& ch : {prev_ch, next_ch}) {
      auto rc = ch->Block();
      if (!rc.OK()) {
        return rc;
      }
    }
    return Success();
  };

  rc = std::move(rc) << [&] {
    return cpu_impl::RingAllgather(comm, s_buffer, HOST_NAME_MAX, 0, prev_ch, next_ch);
  } << [&] { return block(); };
  if (!rc.OK()) {
    return Fail("Failed to get host names from peers.", std::move(rc));
  }

  std::vector<std::int32_t> peers_port(comm.World(), -1);
  peers_port[comm.Rank()] = ninfo.port;
  rc = std::move(rc) << [&] {
    auto s_ports = common::Span{reinterpret_cast<std::int8_t*>(peers_port.data()),
                                peers_port.size() * sizeof(ninfo.port)};
    return cpu_impl::RingAllgather(comm, s_ports, sizeof(ninfo.port), 0, prev_ch, next_ch);
  } << [&] { return block(); };
  if (!rc.OK()) {
    return Fail("Failed to get the port from peers.", std::move(rc));
  }

  std::vector<proto::PeerInfo> peers(comm.World());
  for (auto r = 0; r < comm.World(); ++r) {
    auto nhost = s_buffer.subspan(HOST_NAME_MAX * r, HOST_NAME_MAX);
    auto nport = peers_port[r];
    auto nrank = BootstrapNext(r, comm.World());
    // A host name may fill its whole slot, in which case there is no terminating NUL.
    // Bound the search to the slot instead of reading it as a C string.
    auto const* h_begin = reinterpret_cast<char const*>(nhost.data());
    auto const* h_end = std::find(h_begin, h_begin + nhost.size(), '\0');
    peers[nrank] = {std::string(h_begin, h_end), nport, nrank};
  }
  CHECK_EQ(peers[comm.Rank()].port, lport);
  for (auto const& p : peers) {
    CHECK_NE(p.port, -1);
  }

  std::vector<std::shared_ptr<TCPSocket>>& workers = *out_workers;
  workers.resize(comm.World());

  // Actively connect to every peer with a higher rank and tell it who we are.
  for (std::int32_t r = (comm.Rank() + 1); r < comm.World(); ++r) {
    auto const& peer = peers[r];
    std::shared_ptr<TCPSocket> worker{TCPSocket::CreatePtr(comm.Domain())};
    rc = std::move(rc)
         << [&] { return Connect(peer.host, peer.port, retry, timeout, worker.get()); }
         << [&] { return worker->RecvTimeout(timeout); };
    if (!rc.OK()) {
      return rc;
    }

    auto rank = comm.Rank();
    auto n_bytes = worker->SendAll(&rank, sizeof(comm.Rank()));
    if (n_bytes != sizeof(comm.Rank())) {
      return Fail("Failed to send rank.");
    }
    workers[r] = std::move(worker);
  }

  // Accept a connection from every peer with a lower rank; the peer sends its rank.
  for (std::int32_t r = 0; r < comm.Rank(); ++r) {
    SockAddrV4 addr;
    auto peer = std::shared_ptr<TCPSocket>(TCPSocket::CreatePtr(comm.Domain()));
    rc = std::move(rc) << [&] { return listener->Accept(peer.get(), &addr); }
                       << [&] { return peer->RecvTimeout(timeout); };
    if (!rc.OK()) {
      return rc;
    }
    std::int32_t rank{-1};
    auto n_bytes = peer->RecvAll(&rank, sizeof(rank));
    if (n_bytes != sizeof(comm.Rank())) {
      return Fail("Failed to recv rank.");
    }
    // The rank comes from the network; never use it as an index without validation.
    if (rank < 0 || rank >= comm.World()) {
      return Fail("Received an invalid rank from a peer.");
    }
    workers[rank] = std::move(peer);
  }

  for (std::int32_t r = 0; r < comm.World(); ++r) {
    if (r == comm.Rank()) {
      continue;
    }
    CHECK(workers[r]);
  }

  return Success();
}
/**
 * @brief Construct the communicator and immediately bootstrap it against the tracker.
 *
 * A failed bootstrap trips the CHECK with the error report of the result.
 */
RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
                     std::int32_t retry, std::string task_id)
    // `host` is a const reference; std::move on it would silently degrade to a copy.
    : Comm{host, port, timeout, retry, std::move(task_id)} {
  auto rc = this->Bootstrap(timeout_, retry_, task_id_);
  CHECK(rc.OK()) << rc.Report();
}
// Bootstrap this worker:
//  1. connect to the tracker,
//  2. open a listening socket for peers and a second one for error notices,
//  3. send the start command and receive the world size,
//  4. receive the ring neighbour and derive this worker's rank from it,
//  5. connect to all other workers and create one channel per worker.
[[nodiscard]] Result RabitComm::Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
std::string task_id) {
TCPSocket tracker;
// The world size is not known yet; -1 is what gets sent in the connect message.
std::int32_t world{-1};
auto rc = ConnectTrackerImpl(this->TrackerInfo(), timeout, retry, task_id, &tracker, this->Rank(),
world);
if (!rc.OK()) {
return Fail("Bootstrap failed.", std::move(rc));
}
this->domain_ = tracker.Domain();
// Start command
TCPSocket listener = TCPSocket::Create(tracker.Domain());
std::int32_t lport = listener.BindHost();
listener.Listen();
// create worker for listening to error notice.
auto domain = tracker.Domain();
std::shared_ptr<TCPSocket> error_sock{TCPSocket::CreatePtr(domain)};
auto eport = error_sock->BindHost();
error_sock->Listen();
// NOTE(review): the thread captures `this` and is detached below, so it can outlive the
// communicator — confirm the lifetime is acceptable.
error_worker_ = std::thread{[this, error_sock = std::move(error_sock)] {
auto conn = error_sock->Accept();
// On Windows accept returns an invalid socket after network is shutdown.
if (conn.IsClosed()) {
return;
}
LOG(WARNING) << "Another worker is running into error.";
std::string scmd;
conn.Recv(&scmd);
// NOTE(review): the parsed command is not used afterwards in this block.
auto jcmd = Json::Load(scmd);
auto rc = this->Shutdown();
if (!rc.OK()) {
LOG(WARNING) << "Fail to shutdown worker:" << rc.Report();
}
// Terminate the process, or raise a fatal log in strict R mode.
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
exit(-1);
#else
LOG(FATAL) << rc.Report();
#endif
}};
error_worker_.detach();
proto::Start start;
// Tell the tracker both ports, then read back the world size.
rc = std::move(rc) << [&] { return start.WorkerSend(lport, &tracker, eport); }
<< [&] { return start.WorkerRecv(&tracker, &world); };
if (!rc.OK()) {
return rc;
}
this->world_ = world;
// get ring neighbors
std::string snext;
// NOTE(review): the return value of Recv is not checked here.
tracker.Recv(&snext);
auto jnext = Json::Load(StringView{snext});
proto::PeerInfo ninfo{jnext};
// get the rank of this worker
this->rank_ = BootstrapPrev(ninfo.rank, world);
this->tracker_.rank = rank_;
std::vector<std::shared_ptr<TCPSocket>> workers;
rc = ConnectWorkers(*this, &listener, lport, ninfo, timeout, retry, &workers);
if (!rc.OK()) {
return rc;
}
CHECK(this->channels_.empty());
// One channel per entry in `workers`; non-null sockets are set to no-delay and
// non-blocking first.
for (auto& w : workers) {
if (w) {
w->SetNoDelay();
rc = w->NonBlocking(true);
}
if (!rc.OK()) {
return rc;
}
this->channels_.emplace_back(std::make_shared<Channel>(*this, w));
}
return rc;
}
// Shut the communicator down on destruction. Nothing is done when the communicator is
// not distributed; a failed shutdown is only logged as a warning.
RabitComm::~RabitComm() noexcept(false) {
if (!IsDistributed()) {
return;
}
auto rc = this->Shutdown();
if (!rc.OK()) {
LOG(WARNING) << rc.Report();
}
}
/**
 * @brief Notify the tracker that this worker is shutting down.
 *
 * Opens a fresh connection to the tracker, waits for pending operations of this
 * communicator to finish, then sends the shutdown command.
 */
[[nodiscard]] Result RabitComm::Shutdown() {
  TCPSocket tracker;
  return Success() << [&] {
    return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World());
  } << [&] {
    return this->Block();
  } << [&] {
    Json jcmd{Object{}};
    jcmd["cmd"] = Integer{static_cast<std::int32_t>(proto::CMD::kShutdown)};
    auto scmd = Json::Dump(jcmd);
    auto n_bytes = tracker.Send(scmd);
    if (n_bytes != scmd.size()) {
      return Fail("Failed to send cmd.");
    }
    return Success();
  };
}
/**
 * @brief Send `msg` to the tracker as a print command over a fresh connection.
 */
[[nodiscard]] Result RabitComm::LogTracker(std::string msg) const {
  TCPSocket out;
  proto::Print print;
  return Success() << [&] { return this->ConnectTracker(&out); }
                   // `msg` is our own copy and is not used afterwards: hand it over
                   // instead of copying it a second time.
                   << [&] { return print.WorkerSend(&out, std::move(msg)); };
}
// Report the failed result `res` to the tracker as an error command over a fresh connection.
[[nodiscard]] Result RabitComm::SignalError(Result const& res) {
TCPSocket out;
return Success() << [&] { return this->ConnectTracker(&out); }
<< [&] { return proto::ErrorCMD{}.WorkerSend(&out, res); };
}
} // namespace xgboost::collective

160
src/collective/comm.h Normal file
View File

@ -0,0 +1,160 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <chrono> // for seconds
#include <condition_variable> // for condition_variable
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <memory> // for shared_ptr
#include <mutex> // for mutex
#include <queue> // for queue
#include <string> // for string
#include <thread> // for thread
#include <type_traits> // for remove_const_t
#include <utility> // for move
#include <vector> // for vector
#include "../common/timer.h"
#include "loop.h" // for Loop
#include "protocol.h" // for PeerInfo
#include "xgboost/collective/result.h" // for Result
#include "xgboost/collective/socket.h" // for TCPSocket
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
/** @brief Default timeout in seconds. */
constexpr std::int32_t DefaultTimeoutSec() {
  return 5 * 60;  // 5min
}
/** @brief Default number of retries. */
constexpr std::int32_t DefaultRetry() {
  return 3;
}
// indexing into the ring
/** @brief Rank that follows `r` in a ring of `world` workers (wraps around to 0). */
inline std::int32_t BootstrapNext(std::int32_t r, std::int32_t world) {
  return (r + world + 1) % world;
}
/** @brief Rank that precedes `r` in a ring of `world` workers (wraps around to world - 1). */
inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) {
  return (r + world - 1) % world;
}
class Channel;
/**
* @brief Base communicator storing info about the tracker and other communicators.
*
* The class is neither copyable nor movable.
*/
class Comm {
protected:
// Number of workers; defaults to 1.
std::int32_t world_{1};
// Rank of this worker; defaults to 0.
std::int32_t rank_{0};
std::chrono::seconds timeout_{DefaultTimeoutSec()};
std::int32_t retry_{DefaultRetry()};
// Address of the tracker.
proto::PeerInfo tracker_;
SockDomain domain_{SockDomain::kV4};
std::thread error_worker_;
std::string task_id_;
// Channels to other workers, looked up by rank in Chan().
std::vector<std::shared_ptr<Channel>> channels_;
// Loop that Submit() and Block() delegate to.
std::shared_ptr<Loop> loop_{new Loop{std::chrono::seconds{
DefaultTimeoutSec()}}}; // fixme: require federated comm to have a timeout
public:
Comm() = default;
Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, std::int32_t retry,
std::string task_id);
virtual ~Comm() noexcept(false) {} // NOLINT
Comm(Comm const& that) = delete;
Comm& operator=(Comm const& that) = delete;
Comm(Comm&& that) = delete;
Comm& operator=(Comm&& that) = delete;
// Returns a copy of the tracker info.
[[nodiscard]] auto TrackerInfo() const { return tracker_; }
[[nodiscard]] Result ConnectTracker(TCPSocket* out) const;
[[nodiscard]] auto Domain() const { return domain_; }
[[nodiscard]] auto Timeout() const { return timeout_; }
[[nodiscard]] auto Rank() const { return rank_; }
[[nodiscard]] auto World() const { return world_; }
// Distributed means more than one worker.
[[nodiscard]] bool IsDistributed() const { return World() > 1; }
// Hand an operation to the loop.
void Submit(Loop::Op op) const { loop_->Submit(op); }
// Delegates to the loop's Block().
[[nodiscard]] Result Block() const { return loop_->Block(); }
// Channel for the given rank; uses bounds-checked access.
[[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const {
return channels_.at(rank);
}
[[nodiscard]] virtual bool IsFederated() const = 0;
[[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
// Default implementation does nothing and reports success.
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
};
// Communicator that bootstraps itself against a tracker; reports itself as not federated.
class RabitComm : public Comm {
// Connect to the tracker and to the other workers.
[[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
std::string task_id);
// Send the shutdown command to the tracker.
[[nodiscard]] Result Shutdown();
public:
// bootstrapping construction.
RabitComm() = default;
// ctor for testing where environment is known.
RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t retry, std::string task_id);
~RabitComm() noexcept(false) override;
[[nodiscard]] bool IsFederated() const override { return false; }
// Send a message to the tracker as a print command.
[[nodiscard]] Result LogTracker(std::string msg) const override;
// Send a failed result to the tracker as an error command.
[[nodiscard]] Result SignalError(Result const&) override;
};
/**
* @brief Communication channel between workers.
*
* SendAll and RecvAll only submit an operation to the communicator; call Block() to wait
* on the communicator.
*/
class Channel {
std::shared_ptr<TCPSocket> sock_{nullptr};
// NOTE(review): not referenced by any member function visible here.
Result rc_;
// Stored by reference.
Comm const& comm_;
public:
explicit Channel(Comm const& comm, std::shared_ptr<TCPSocket> sock)
: sock_{std::move(sock)}, comm_{comm} {}
// Submit a write of `n` bytes starting at `ptr`.
void SendAll(std::int8_t const* ptr, std::size_t n) {
// The operation stores a non-const pointer, hence the const_cast.
Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast<std::int8_t*>(ptr), n, sock_.get(), 0};
CHECK(sock_.get());
comm_.Submit(std::move(op));
}
void SendAll(common::Span<std::int8_t const> data) {
this->SendAll(data.data(), data.size_bytes());
}
// Submit a read of `n` bytes into `ptr`.
void RecvAll(std::int8_t* ptr, std::size_t n) {
Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0};
CHECK(sock_.get());
comm_.Submit(std::move(op));
}
void RecvAll(common::Span<std::int8_t> data) { this->RecvAll(data.data(), data.size_bytes()); }
[[nodiscard]] auto Socket() const { return sock_; }
// Delegates to the communicator's Block().
[[nodiscard]] Result Block() { return comm_.Block(); }
};
// Operation kinds: max, min, sum and the bitwise AND/OR/XOR.
// NOTE(review): presumably the reduction operators for collective calls — confirm against users.
enum class Op { kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 };
/**
 * @brief View a typed span as a span of bytes (`std::int8_t`), keeping constness.
 *
 * The resulting span covers the same memory; its length is the byte size of the input.
 */
template <typename T, typename U = std::conditional_t<std::is_const_v<T>,
                                                      std::add_const_t<std::int8_t>, std::int8_t>>
common::Span<U> EraseType(common::Span<T> data) {
  return common::Span<U>{reinterpret_cast<U*>(data.data()), data.size_bytes()};
}
/**
 * @brief View a span of bytes (`std::int8_t`) as a span of `T`.
 *
 * The element count is the byte size divided by `sizeof(T)` (integer division).
 */
template <typename T, typename U>
common::Span<T> RestoreType(common::Span<U> data) {
  static_assert(std::is_same_v<std::remove_const_t<U>, std::int8_t>);
  return common::Span<T>{reinterpret_cast<T*>(data.data()), data.size_bytes() / sizeof(T)};
}
} // namespace xgboost::collective

214
src/collective/protocol.h Normal file
View File

@ -0,0 +1,214 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int32_t
#include <string> // for string
#include <utility> // for move
#include "xgboost/collective/result.h" // for Result
#include "xgboost/collective/socket.h" // for TCPSocket
#include "xgboost/json.h" // for Json
namespace xgboost::collective::proto {
// Address of a peer: host, port and rank. Port and rank default to -1.
struct PeerInfo {
std::string host;
std::int32_t port{-1};
std::int32_t rank{-1};
PeerInfo() = default;
PeerInfo(std::string host, std::int32_t port, std::int32_t rank)
: host{std::move(host)}, port{port}, rank{rank} {}
// Read the "host", "port" and "rank" fields from a JSON object.
explicit PeerInfo(Json const& peer)
: host{get<String>(peer["host"])},
port{static_cast<std::int32_t>(get<Integer const>(peer["port"]))},
rank{static_cast<std::int32_t>(get<Integer const>(peer["rank"]))} {}
// Serialise into a JSON object with the same three fields the JSON constructor reads.
[[nodiscard]] Json ToJson() const {
Json info{Object{}};
info["rank"] = rank;
info["host"] = String{host};
info["port"] = Integer{port};
return info;
}
// "host:port" as a string.
[[nodiscard]] auto HostPort() const { return host + ":" + std::to_string(this->port); }
};
/**
 * @brief Handshake based on a magic number: send it, read a number back, and require the
 *        two to be equal.
 */
struct Magic {
  static constexpr std::int32_t kMagic = 0xff99;

  [[nodiscard]] Result Verify(xgboost::collective::TCPSocket* p_sock) {
    std::int32_t token{kMagic};
    if (p_sock->SendAll(&token, sizeof(token)) != sizeof(token)) {
      return Fail("Failed to verify.");
    }

    // Clear the buffer before reading the peer's answer into it.
    token = 0;
    if (p_sock->RecvAll(&token, sizeof(token)) != sizeof(token)) {
      return Fail("Failed to verify.");
    }

    if (token == kMagic) {
      return Success();
    }
    return xgboost::collective::Fail("Invalid verification number.");
  }
};
// Command codes carried in the "cmd" field of the JSON messages.
enum class CMD : std::int32_t {
kInvalid = 0,
kStart = 1,
kShutdown = 2,
kError = 3,
kPrint = 4,
};
/**
 * @brief The first message a worker sends to the tracker: world size, rank and task id.
 */
struct Connect {
  /**
   * @brief Worker side: serialise the three fields into JSON and send them.
   */
  [[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::int32_t world, std::int32_t rank,
                                  std::string task_id) const {
    Json jinit{Object{}};
    jinit["world_size"] = Integer{world};
    jinit["rank"] = Integer{rank};
    // `task_id` is our own copy, move it into the document.
    jinit["task_id"] = String{std::move(task_id)};
    std::string msg;
    Json::Dump(jinit, &msg);
    auto n_bytes = tracker->Send(msg);
    if (n_bytes != msg.size()) {
      return Fail("Failed to send init command from worker.");
    }
    return Success();
  }
  /**
   * @brief Tracker side: receive the message and extract the three fields.
   *
   * The payload comes from the network, so an empty message or a field with an
   * unexpected type is reported as a failed Result; the outputs are then left untouched.
   */
  [[nodiscard]] Result TrackerRecv(TCPSocket* sock, std::int32_t* world, std::int32_t* rank,
                                   std::string* task_id) const {
    std::string init;
    auto n_bytes = sock->Recv(&init);
    if (n_bytes <= 0) {
      return Fail("Failed to recv init command from worker.");
    }
    auto jinit = Json::Load(StringView{init});
    if (!IsA<Integer>(jinit["world_size"]) || !IsA<Integer>(jinit["rank"]) ||
        !IsA<String>(jinit["task_id"])) {
      return Fail("Invalid init command.");
    }
    *world = get<Integer const>(jinit["world_size"]);
    *rank = get<Integer const>(jinit["rank"]);
    *task_id = get<String const>(jinit["task_id"]);
    return Success();
  }
};
// The start command: a worker announces its listening port and error port, and the
// tracker answers with the world size.
class Start {
private:
// Tracker side: send the world size to the worker.
[[nodiscard]] Result TrackerSend(std::int32_t world, TCPSocket* worker) const {
Json jcmd{Object{}};
jcmd["world_size"] = Integer{world};
auto scmd = Json::Dump(jcmd);
auto n_bytes = worker->Send(scmd);
if (n_bytes != scmd.size()) {
return Fail("Failed to send init command from tracker.");
}
return Success();
}
public:
// Worker side: send the start command with the listening port and the error port.
[[nodiscard]] Result WorkerSend(std::int32_t lport, TCPSocket* tracker,
std::int32_t eport) const {
Json jcmd{Object{}};
jcmd["cmd"] = Integer{static_cast<std::int32_t>(CMD::kStart)};
jcmd["port"] = Integer{lport};
jcmd["error_port"] = Integer{eport};
auto scmd = Json::Dump(jcmd);
auto n_bytes = tracker->Send(scmd);
if (n_bytes != scmd.size()) {
return Fail("Failed to send init command from worker.");
}
return Success();
}
// Worker side: receive the tracker's answer and store the world size in `p_world`.
// Fails when nothing is received or the world size is not positive.
[[nodiscard]] Result WorkerRecv(TCPSocket* tracker, std::int32_t* p_world) const {
std::string scmd;
auto n_bytes = tracker->Recv(&scmd);
if (n_bytes <= 0) {
return Fail("Failed to recv init command from tracker.");
}
auto jcmd = Json::Load(scmd);
auto world = get<Integer const>(jcmd["world_size"]);
if (world <= 0) {
return Fail("Invalid world size.");
}
*p_world = world;
return Success();
}
// Tracker side: handle a start command that has already been parsed into `jcmd`.
// Reads the worker's port and error port, records the world size in `recv_world`, and
// replies with the world size through `p_sock`.
[[nodiscard]] Result TrackerHandle(Json jcmd, std::int32_t* recv_world, std::int32_t world,
std::int32_t* p_port, TCPSocket* p_sock,
std::int32_t* eport) const {
*p_port = get<Integer const>(jcmd["port"]);
if (*p_port <= 0) {
return Fail("Invalid port.");
}
// `recv_world` must still hold -1 when the start command arrives.
if (*recv_world != -1) {
return Fail("Invalid initialization sequence.");
}
*recv_world = world;
*eport = get<Integer const>(jcmd["error_port"]);
return TrackerSend(world, p_sock);
}
};
/**
 * @brief The print command: a worker asks the tracker to print a message.
 */
struct Print {
  /** @brief Worker side: send `msg` wrapped in a print command. */
  [[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::string msg) const {
    Json jcmd{Object{}};
    jcmd["cmd"] = Integer{static_cast<std::int32_t>(CMD::kPrint)};
    jcmd["msg"] = String{std::move(msg)};

    auto const payload = Json::Dump(jcmd);
    if (tracker->Send(payload) != payload.size()) {
      return Fail("Failed to send print command from worker.");
    }
    return Success();
  }
  /** @brief Tracker side: extract the message; fails when "msg" is not a string. */
  [[nodiscard]] Result TrackerHandle(Json jcmd, std::string* p_msg) const {
    if (IsA<String>(jcmd["msg"])) {
      *p_msg = get<String const>(jcmd["msg"]);
      return Success();
    }
    return Fail("Invalid print command.");
  }
};
/**
 * @brief The error command: a worker reports a failed result to the tracker.
 */
struct ErrorCMD {
  /** @brief Worker side: send the report and the error code of `res`. */
  [[nodiscard]] Result WorkerSend(TCPSocket* tracker, Result const& res) const {
    auto report = res.Report();
    auto errc = res.Code().value();

    Json jcmd{Object{}};
    jcmd["msg"] = String{std::move(report)};
    jcmd["code"] = Integer{errc};
    jcmd["cmd"] = Integer{static_cast<std::int32_t>(CMD::kError)};

    auto const payload = Json::Dump(jcmd);
    if (tracker->Send(payload) != payload.size()) {
      return Fail("Failed to send error command from worker.");
    }
    return Success();
  }
  /**
   * @brief Tracker side: extract the message and the code; fails when either field has
   *        an unexpected type.
   */
  [[nodiscard]] Result TrackerHandle(Json jcmd, std::string* p_msg, int* p_code) const {
    bool const well_formed = IsA<String>(jcmd["msg"]) && IsA<Integer>(jcmd["code"]);
    if (!well_formed) {
      return Fail("Invalid error command.");
    }
    *p_msg = get<String const>(jcmd["msg"]);
    *p_code = get<Integer const>(jcmd["code"]);
    return Success();
  }
};
/**
 * @brief The shutdown command.
 */
struct ShutdownCMD {
  /** @brief Send a JSON document whose "cmd" field is the shutdown code. */
  [[nodiscard]] Result Send(TCPSocket* peer) const {
    Json jcmd{Object{}};
    jcmd["cmd"] = Integer{static_cast<std::int32_t>(proto::CMD::kShutdown)};

    auto const payload = Json::Dump(jcmd);
    if (peer->Send(payload) != payload.size()) {
      return Fail("Failed to send shutdown command from worker.");
    }
    return Success();
  }
};
} // namespace xgboost::collective::proto

View File

@ -15,12 +15,232 @@
#include <ws2tcpip.h>
#endif // defined(_WIN32)
#include <string> // for string
#include <algorithm> // for sort
#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <string> // for string
#include <utility> // for move, forward
#include "../common/json_utils.h"
#include "comm.h"
#include "protocol.h" // for kMagic, PeerInfo
#include "tracker.h"
#include "xgboost/collective/result.h" // for Result, Fail, Success
#include "xgboost/collective/socket.h" // for GetHostName, FailWithCode, MakeSockAddress, ...
#include "xgboost/json.h"
namespace xgboost::collective {
// Read the tracker configuration from JSON:
//  - "n_workers": required,
//  - "port": optional, defaults to 0,
//  - "timeout": optional, in seconds, defaults to DefaultTimeoutSec().
Tracker::Tracker(Json const& config)
: n_workers_{static_cast<std::int32_t>(
RequiredArg<Integer const>(config, "n_workers", __func__))},
port_{static_cast<std::int32_t>(OptionalArg<Integer const>(config, "port", Integer::Int{0}))},
timeout_{std::chrono::seconds{OptionalArg<Integer const>(
config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
// Tracker-side handling of a freshly accepted worker connection: verify the magic number,
// read the connect message, then read one command and dispatch on it. The outcome is
// stored in `rc_`; on failure the constructor returns early and `info_` is not assigned.
RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr)
: sock_{std::move(sock)} {
auto host = addr.Addr();
std::int32_t rank{0};
rc_ = Success()
<< [&] { return proto::Magic{}.Verify(&sock_); }
<< [&] { return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_); };
if (!rc_.OK()) {
return;
}
std::string cmd;
// NOTE(review): the return value of Recv is not checked before parsing.
sock_.Recv(&cmd);
auto jcmd = Json::Load(StringView{cmd});
cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
// Only the start command fills in the port; for other commands it stays 0.
std::int32_t port{0};
if (cmd_ == proto::CMD::kStart) {
proto::Start start;
rc_ = start.TrackerHandle(jcmd, &world_, world, &port, &sock_, &eport_);
} else if (cmd_ == proto::CMD::kPrint) {
proto::Print print;
rc_ = print.TrackerHandle(jcmd, &msg_);
} else if (cmd_ == proto::CMD::kError) {
proto::ErrorCMD error;
rc_ = error.TrackerHandle(jcmd, &msg_, &code_);
}
if (!rc_.OK()) {
return;
}
info_ = proto::PeerInfo{host, port, rank};
}
// Create the listening socket of the tracker. The host comes from the optional "host"
// entry of the configuration and falls back to the address found by GetHostAddress.
RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
std::string self;
// NOTE(review): this result is overwritten by Bind below without being checked.
auto rc = collective::GetHostAddress(&self);
auto host = OptionalArg<String>(config, "host", self);
listener_ = TCPSocket::Create(SockDomain::kV4);
// `port_` is passed by pointer to Bind.
rc = listener_.Bind(host, &this->port_);
CHECK(rc.OK()) << rc.Report();
listener_.Listen();
}
// Send every worker the address of its successor in the ring, then remember the host and
// error port of each worker. The workers are sorted with WorkerCmp first.
Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
auto& workers = *p_workers;
std::sort(workers.begin(), workers.end(), WorkerCmp{});
// One thread per worker; all of them are joined before this function returns, so the
// references captured by the lambdas stay valid.
std::vector<std::thread> bootstrap_threads;
for (std::int32_t r = 0; r < n_workers_; ++r) {
auto& worker = workers[r];
auto next = BootstrapNext(r, n_workers_);
auto const& next_w = workers[next];
bootstrap_threads.emplace_back([next, &worker, &next_w] {
auto jnext = proto::PeerInfo{next_w.Host(), next_w.Port(), next}.ToJson();
std::string str;
Json::Dump(jnext, &str);
// NOTE(review): the result of Send is not checked.
worker.Send(StringView{str});
});
}
for (auto& t : bootstrap_threads) {
t.join();
}
for (auto const& w : workers) {
worker_error_handles_.emplace_back(w.Host(), w.ErrorPort());
}
return Success();
}
/**
 * @brief Start the tracker loop asynchronously.
 *
 * The loop accepts one connection at a time, wraps it in a WorkerProxy and dispatches
 * on the command the worker sent:
 *  - start: collect the worker; once all workers are collected, bootstrap them.
 *  - shutdown: count the worker as offline.
 *  - error: signal every other known worker to stop (once per restart).
 *  - print: print the message.
 *
 * @return A future holding the final status of the loop. The task captures `this`, so
 *         the tracker must stay alive until the future is ready.
 */
[[nodiscard]] std::future<Result> RabitTracker::Run() {
  // a state machine to keep track of consistency.
  struct State {
    std::int32_t const n_workers;

    std::int32_t n_shutdown{0};
    bool during_restart{false};
    std::vector<WorkerProxy> pending;

    explicit State(std::int32_t world) : n_workers{world} {}
    State(State const& that) = delete;
    State& operator=(State&& that) = delete;

    void Start(WorkerProxy&& worker) {
      CHECK_LT(pending.size(), n_workers);
      CHECK_LE(n_shutdown, n_workers);

      pending.emplace_back(std::forward<WorkerProxy>(worker));

      CHECK_LE(pending.size(), n_workers);
    }
    void Shutdown() {
      CHECK_GE(n_shutdown, 0);
      CHECK_LT(n_shutdown, n_workers);

      ++n_shutdown;

      CHECK_LE(n_shutdown, n_workers);
    }
    void Error() {
      CHECK_LE(pending.size(), n_workers);
      CHECK_LE(n_shutdown, n_workers);

      during_restart = true;
    }
    [[nodiscard]] bool Ready() const {
      CHECK_LE(pending.size(), n_workers);
      return static_cast<std::int32_t>(pending.size()) == n_workers;
    }
    void Bootstrap() {
      CHECK_EQ(pending.size(), n_workers);
      CHECK_LE(n_shutdown, n_workers);

      // A reset.
      n_shutdown = 0;
      during_restart = false;
      pending.clear();
    }
    [[nodiscard]] bool ShouldContinue() const {
      CHECK_LE(pending.size(), n_workers);
      CHECK_LE(n_shutdown, n_workers);
      // - Without error, we should shutdown after all workers are offline.
      // - With error, all workers are offline, and we have during_restart as true.
      return n_shutdown != n_workers || during_restart;
    }
  };

  return std::async(std::launch::async, [this] {
    State state{this->n_workers_};

    while (state.ShouldContinue()) {
      TCPSocket sock;
      SockAddrV4 addr;
      auto rc = listener_.Accept(&sock, &addr);
      if (!rc.OK()) {
        return Fail("Failed to accept connection.", std::move(rc));
      }

      auto worker = WorkerProxy{n_workers_, std::move(sock), std::move(addr)};
      if (!worker.Status().OK()) {
        return Fail("Failed to initialize worker proxy.", std::move(worker.Status()));
      }
      switch (worker.Command()) {
        case proto::CMD::kStart: {
          state.Start(std::move(worker));
          if (state.Ready()) {
            rc = this->Bootstrap(&state.pending);
            state.Bootstrap();
          }
          if (!rc.OK()) {
            return rc;
          }
          continue;
        }
        case proto::CMD::kShutdown: {
          state.Shutdown();
          continue;
        }
        case proto::CMD::kError: {
          // Only the first error of a restart cycle triggers the notification.
          if (state.during_restart) {
            continue;
          }
          state.Error();
          auto msg = worker.Msg();
          auto code = worker.Code();
          LOG(WARNING) << "Received error from [" << worker.Host() << ":" << worker.Rank()
                       << "]: " << msg << " code:" << code;
          auto host = worker.Host();
          // We signal all workers for the error, if they haven't aborted already.
          for (auto& w : worker_error_handles_) {
            if (w.first == host) {
              continue;
            }
            TCPSocket out;
            proto::ShutdownCMD shutdown;
            // retry is set to 1, just let the worker timeout or error. Otherwise the
            // tracker and the worker might be waiting for each other.
            //
            // The stop signal is only sent over a socket that actually connected, and the
            // underlying cause is kept in the returned error.
            auto signal_rc = Connect(w.first, w.second, 1, timeout_, &out)
                             << [&] { return shutdown.Send(&out); };
            if (!signal_rc.OK()) {
              return Fail("Failed to inform workers to stop.", std::move(signal_rc));
            }
          }
          continue;
        }
        case proto::CMD::kPrint: {
          LOG(CONSOLE) << worker.Msg();
          continue;
        }
        case proto::CMD::kInvalid:
        default: {
          return Fail("Invalid command received.");
        }
      }
    }
    return Success();
  });
}
[[nodiscard]] Result GetHostAddress(std::string* out) {
auto rc = GetHostName(out);
if (!rc.OK()) {

View File

@ -2,11 +2,137 @@
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <string> // for string
#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <future> // for future
#include <string> // for string
#include <utility> // for pair
#include <vector> // for vector
#include "protocol.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/collective/socket.h" // for TCPSocket
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
/**
*
* @brief Implementation of RABIT tracker.
*
* * What is a tracker
*
* The implementation of collective follows what RABIT did in the past. It requires a
* tracker to coordinate initialization and error recovery of workers. The original
* implementation attempted to attain error resilience inside the collective
* module, which turned out to be too challenging due to the large amount of external
* state. The new implementation here differs from RABIT in that neither state
* recovery nor resilience is handled inside the collective; it merely provides the
* mechanism to signal errors to other workers through the use of a centralized tracker.
*
* There are three major functionalities provided by a tracker, namely:
* - Initialization. Share the node addresses among all workers.
* - Logging.
* - Signal error. If an exception is thrown in one (or many) of the workers, it can
* signal an error to the tracker and the tracker will notify other workers.
*/
class Tracker {
 protected:
  // Number of workers the tracker expects.
  std::int32_t n_workers_{0};
  // Port of the tracker, -1 until one is assigned.
  std::int32_t port_{-1};
  // Timeout handed to network operations performed by the tracker.
  std::chrono::seconds timeout_{0};

 public:
  // Construct from a JSON configuration; defined out of line.
  explicit Tracker(Json const& config);
  // Construct from explicit values.
  Tracker(std::int32_t n_workers, std::int32_t port, std::chrono::seconds timeout)
      : n_workers_{n_workers}, port_{port}, timeout_{timeout} {}
  virtual ~Tracker() noexcept(false) {}  // NOLINT

  // Start the tracker, the returned future carries its final status.
  [[nodiscard]] virtual std::future<Result> Run() = 0;
  // Arguments that should be passed to workers, as a JSON document.
  [[nodiscard]] virtual Json WorkerArgs() const = 0;
  [[nodiscard]] std::chrono::seconds Timeout() const {
    return timeout_;
  }
};
// Tracker implementation that accepts workers on a TCP (IPv4) listening socket.
class RabitTracker : public Tracker {
  // a wrapper for connected worker socket.
  class WorkerProxy {
    // Connection to the worker.
    TCPSocket sock_;
    // Host, port and rank of the worker.
    proto::PeerInfo info_;
    // Port on which the worker can be reached when an error needs to be signaled.
    std::int32_t eport_{0};
    std::int32_t world_{-1};
    // Identifier of the task, used as a tie breaker when ordering workers.
    std::string task_id_;
    // Command sent by the worker.
    proto::CMD cmd_{proto::CMD::kInvalid};
    // Message and code attached to the command (used by the print and error commands).
    std::string msg_;
    std::int32_t code_{0};
    // Status of the proxy, checked by the tracker right after construction.
    Result rc_;

   public:
    explicit WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr);
    WorkerProxy(WorkerProxy const& that) = delete;
    WorkerProxy(WorkerProxy&& that) = default;
    WorkerProxy& operator=(WorkerProxy const&) = delete;
    WorkerProxy& operator=(WorkerProxy&&) = default;

    [[nodiscard]] auto Host() const { return info_.host; }
    [[nodiscard]] auto TaskID() const { return task_id_; }
    [[nodiscard]] auto Port() const { return info_.port; }
    [[nodiscard]] auto Rank() const { return info_.rank; }
    [[nodiscard]] auto ErrorPort() const { return eport_; }
    [[nodiscard]] auto Command() const { return cmd_; }
    [[nodiscard]] auto Msg() const { return msg_; }
    [[nodiscard]] auto Code() const { return code_; }

    [[nodiscard]] Result const& Status() const { return rc_; }
    [[nodiscard]] Result& Status() { return rc_; }

    // Send a string to the worker through its socket.
    void Send(StringView value) { this->sock_.Send(value); }
  };
  // provide an ordering for workers, this helps us get deterministic topology.
  struct WorkerCmp {
    // Order by host first, then by task ID. The call operator is const so that the
    // comparator can also be invoked through a const object (e.g. inside an ordered
    // container).
    [[nodiscard]] bool operator()(WorkerProxy const& lhs, WorkerProxy const& rhs) const {
      auto const& lh = lhs.Host();
      auto const& rh = rhs.Host();

      if (lh != rh) {
        return lh < rh;
      }
      return lhs.TaskID() < rhs.TaskID();
    }
  };

 private:
  std::string host_;
  // record for how to reach out to workers if error happens.
  std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
  // listening socket for incoming workers.
  TCPSocket listener_;

  Result Bootstrap(std::vector<WorkerProxy>* p_workers);

 public:
  /**
   * @brief Create the tracker and start listening for workers.
   *
   * @param host      Address the listening socket is bound to.
   * @param n_workers Number of workers expected.
   * @param port      Port to bind to. The port is handed to Bind by pointer; presumably
   *                  it is updated with the actual port when 0 is given (the tests pass 0
   *                  and read Port() afterward) -- TODO confirm.
   * @param timeout   Timeout used when connecting to workers.
   *
   * Aborts through CHECK if the socket cannot be bound.
   */
  explicit RabitTracker(StringView host, std::int32_t n_workers, std::int32_t port,
                        std::chrono::seconds timeout)
      : Tracker{n_workers, port, timeout}, host_{host.c_str(), host.size()} {
    listener_ = TCPSocket::Create(SockDomain::kV4);
    auto rc = listener_.Bind(host, &this->port_);
    CHECK(rc.OK()) << rc.Report();
    listener_.Listen();
  }

  explicit RabitTracker(Json const& config);
  ~RabitTracker() noexcept(false) override = default;

  // Marked nodiscard like the base declaration and the out-of-line definition.
  [[nodiscard]] std::future<Result> Run() override;

  [[nodiscard]] std::int32_t Port() const { return port_; }
  // JSON object holding the tracker address ("DMLC_TRACKER_URI") and port
  // ("DMLC_TRACKER_PORT") for workers.
  [[nodiscard]] Json WorkerArgs() const override {
    Json args{Object{}};
    args["DMLC_TRACKER_URI"] = String{host_};
    args["DMLC_TRACKER_PORT"] = this->Port();
    return args;
  }
};
// Probe the public IP address of the host, need a better method.
//
// This is directly translated from the previous Python implementation, we should find a

View File

@ -0,0 +1,47 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include "../../../src/collective/comm.h"
#include "test_worker.h"
namespace xgboost::collective {
namespace {
// Fixture for the communicator tests; adds nothing on top of TrackerTest.
class CommTest : public TrackerTest {};
} // namespace
// Workers are paired up: every even rank sends its own rank to the next worker, every
// odd rank receives from the previous worker and verifies the payload.
TEST_F(CommTest, Channel) {
  auto world = 4;
  RabitTracker tracker{host, world, 0, timeout};
  auto tracker_status = tracker.Run();

  std::int32_t port = tracker.Port();
  std::vector<std::thread> threads;
  for (std::int32_t rank = 0; rank < world; ++rank) {
    threads.emplace_back([=] {
      WorkerForTest worker{host, port, timeout, world, rank};
      bool is_sender = (rank % 2 == 0);
      if (!is_sender) {
        // Receiver side.
        auto p_chan = worker.Comm().Chan(rank - 1);
        std::int32_t received{-1};
        p_chan->RecvAll(
            EraseType(common::Span<std::int32_t>{&received, static_cast<std::size_t>(1)}));
        auto rc = p_chan->Block();
        ASSERT_TRUE(rc.OK()) << rc.Report();
        ASSERT_EQ(received, rank - 1);
        return;
      }
      // Sender side.
      auto p_chan = worker.Comm().Chan(rank + 1);
      p_chan->SendAll(
          EraseType(common::Span<std::int32_t const>{&rank, static_cast<std::size_t>(1)}));
      auto rc = p_chan->Block();
      ASSERT_TRUE(rc.OK()) << rc.Report();
    });
  }

  for (auto &t : threads) {
    t.join();
  }

  ASSERT_TRUE(tracker_status.get().OK());
}
} // namespace xgboost::collective

View File

@ -7,7 +7,7 @@
#include <cerrno> // EADDRNOTAVAIL
#include <system_error> // std::error_code, std::system_category
#include "net_test.h" // for SocketTest
#include "test_worker.h" // for SocketTest
namespace xgboost::collective {
TEST_F(SocketTest, Basic) {

View File

@ -1,18 +1,67 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "../../../src/collective/tracker.h" // for GetHostAddress
#include "net_test.h" // for SocketTest
#include <gtest/gtest.h>
#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <string> // for string
#include <thread> // for thread
#include <vector> // for vector
#include "../../../src/collective/comm.h"
#include "test_worker.h"
namespace xgboost::collective {
namespace {
// Fixture for the tracker tests; adds nothing on top of SocketTest.
// NOTE(review): a fixture with the same name appears to be declared in test_worker.h,
// which this file includes -- confirm this local definition is still needed.
class TrackerTest : public SocketTest {};
// Test worker that sends a message to the tracker log.
class PrintWorker : public WorkerForTest {
 public:
  using WorkerForTest::WorkerForTest;

  // Log "ack:<rank>" through the tracker and assert that the call succeeded.
  void Print() {
    std::string message = "ack:" + std::to_string(this->comm_.Rank());
    auto status = comm_.LogTracker(message);
    ASSERT_TRUE(status.OK()) << status.Report();
  }
};
} // namespace
TEST_F(TrackerTest, GetHostAddress) {
std::string host;
auto rc = GetHostAddress(&host);
ASSERT_TRUE(rc.OK());
ASSERT_TRUE(host.find("127.") == std::string::npos);
// Every worker only constructs a WorkerForTest on its own thread; the tracker is expected
// to finish with a successful status afterward.
TEST_F(TrackerTest, Bootstrap) {
  RabitTracker tracker{host, n_workers, 0, timeout};
  auto tracker_status = tracker.Run();

  std::int32_t port = tracker.Port();
  std::vector<std::thread> threads;
  for (std::int32_t rank = 0; rank < n_workers; ++rank) {
    threads.emplace_back([=] {
      WorkerForTest worker{host, port, timeout, n_workers, rank};
    });
  }
  for (auto &t : threads) {
    t.join();
  }

  ASSERT_TRUE(tracker_status.get().OK());
}
// Every worker sends one message to the tracker log; the tracker is expected to finish
// with a successful status afterward.
TEST_F(TrackerTest, Print) {
  RabitTracker tracker{host, n_workers, 0, timeout};
  auto tracker_status = tracker.Run();

  std::int32_t port = tracker.Port();
  std::vector<std::thread> threads;
  for (std::int32_t rank = 0; rank < n_workers; ++rank) {
    threads.emplace_back([=] {
      PrintWorker printer{host, port, timeout, n_workers, rank};
      printer.Print();
    });
  }
  for (auto &t : threads) {
    t.join();
  }

  ASSERT_TRUE(tracker_status.get().OK());
}
// The address resolved by the fixture must not be a loopback address.
TEST_F(TrackerTest, GetHostAddress) {
  // Only a leading "127." marks a loopback address. Searching the whole string would
  // reject valid addresses that merely contain that sequence (e.g. "192.168.127.1").
  // rfind with position 0 can only match at the very beginning of the string.
  ASSERT_TRUE(host.rfind("127.", 0) == std::string::npos);
}
} // namespace xgboost::collective

View File

@ -0,0 +1,91 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <string> // for string
#include <thread> // for thread
#include <utility> // for move
#include <vector> // for vector
#include "../../../src/collective/comm.h"
#include "../../../src/collective/tracker.h" // for GetHostAddress
#include "../helpers.h" // for FileExists
namespace xgboost::collective {
// A worker used by tests: it owns a RabitComm that is constructed against the given
// tracker address.
class WorkerForTest {
  std::string tracker_host_;
  std::int32_t tracker_port_;
  std::int32_t world_size_;

 protected:
  // Retry count handed to the communicator.
  std::int32_t retry_{1};
  // "t:<rank>", handed to the communicator as the task ID.
  std::string task_id_;
  RabitComm comm_;

 public:
  // Build the communicator and check that it reports the expected world size.
  WorkerForTest(std::string host, std::int32_t port, std::chrono::seconds timeout,
                std::int32_t world, std::int32_t rank)
      : tracker_host_{std::move(host)},
        tracker_port_{port},
        world_size_{world},
        task_id_{"t:" + std::to_string(rank)},
        comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_} {
    CHECK_EQ(world_size_, comm_.World());
  }
  virtual ~WorkerForTest() = default;

  auto& Comm() { return comm_; }

  // Set the buffer size of the socket to every other worker, asserting on the way that
  // each of those sockets is non-blocking.
  void LimitSockBuf(std::int32_t n_bytes) {
    for (std::int32_t peer = 0; peer < comm_.World(); ++peer) {
      if (peer == comm_.Rank()) {
        // No channel to ourselves.
        continue;
      }
      ASSERT_TRUE(comm_.Chan(peer)->Socket()->NonBlocking());
      ASSERT_TRUE(comm_.Chan(peer)->Socket()->SetBufSize(n_bytes).OK());
    }
  }
};
class SocketTest : public ::testing::Test {
protected:
std::string skip_msg_{"Skipping IPv6 test"};
bool SkipTest() {
std::string path{"/sys/module/ipv6/parameters/disable"};
if (FileExists(path)) {
std::ifstream fin(path);
if (!fin) {
return true;
}
std::string s_value;
fin >> s_value;
auto value = std::stoi(s_value);
if (value != 0) {
return true;
}
} else {
return true;
}
return false;
}
protected:
void SetUp() override { system::SocketStartup(); }
void TearDown() override { system::SocketFinalize(); }
};
// Fixture for tests that need a tracker: on top of SocketTest it resolves the host
// address before each test.
class TrackerTest : public SocketTest {
 public:
  // Number of workers used by the tests.
  std::int32_t n_workers{2};
  // Timeout handed to the tracker and the workers.
  std::chrono::seconds timeout{1};
  // Host address, filled in by SetUp.
  std::string host;

  void SetUp() override {
    SocketTest::SetUp();
    auto status = GetHostAddress(&host);
    ASSERT_TRUE(status.OK()) << status.Report();
  }
};
} // namespace xgboost::collective