enable ROCm on latest XGBoost

This commit is contained in:
Hui Liu
2023-10-23 11:07:08 -07:00
328 changed files with 8028 additions and 3642 deletions

View File

@@ -15,8 +15,7 @@
#include "communicator-inl.cuh"
namespace xgboost {
namespace collective {
namespace xgboost::collective {
/**
* @brief Find the global sum of the given values across all workers.
@@ -31,10 +30,9 @@ namespace collective {
* @param size Number of values to sum.
*/
template <typename T>
void GlobalSum(MetaInfo const& info, int device, T* values, size_t size) {
void GlobalSum(MetaInfo const& info, DeviceOrd device, T* values, size_t size) {
if (info.IsRowSplit()) {
collective::AllReduce<collective::Operation::kSum>(device, values, size);
collective::AllReduce<collective::Operation::kSum>(device.ordinal, values, size);
}
}
} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective

View File

@@ -0,0 +1,88 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "allgather.h"
#include <algorithm> // for min, copy_n
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int32_t, int64_t
#include <memory> // for shared_ptr
#include <numeric> // for partial_sum
#include <vector> // for vector
#include "comm.h" // for Comm, Channel
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective::cpu_impl {
// Ring-based allgather over raw bytes.  Each worker owns one fixed-size segment of
// `data`; the loop circulates segments around the ring until every worker has seen
// every segment.
//
// @param comm         Communicator describing the ring (rank/world).
// @param data         Buffer holding all segments; this worker's own segment must
//                     already be filled in.
// @param segment_size Size in bytes of each worker's segment.
// @param worker_off   Segment ownership offset (see the header for semantics).
// @param prev_ch      Channel to the ring-previous worker (receive side).
// @param next_ch      Channel to the ring-next worker (send side).
Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size_t segment_size,
                     std::int32_t worker_off, std::shared_ptr<Channel> prev_ch,
                     std::shared_ptr<Channel> next_ch) {
  auto world = comm.World();
  auto rank = comm.Rank();
  CHECK_LT(worker_off, world);

  for (std::int32_t r = 0; r < world; ++r) {
    // Send the segment received in the previous round (initially our own) onward.
    auto send_rank = (rank + world - r + worker_off) % world;
    auto send_off = send_rank * segment_size;
    // Clamp offset and size: the tail segment may be shorter than segment_size.
    send_off = std::min(send_off, data.size_bytes());
    auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off));
    next_ch->SendAll(send_seg.data(), send_seg.size_bytes());

    // Receive the segment one ring position behind from the previous worker.
    auto recv_rank = (rank + world - r - 1 + worker_off) % world;
    auto recv_off = recv_rank * segment_size;
    recv_off = std::min(recv_off, data.size_bytes());
    auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off));
    prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
    // Wait for this round's queued send/recv before touching the buffers again.
    auto rc = prev_ch->Block();
    if (!rc.OK()) {
      return rc;
    }
  }
  return Success();
}
// Variable-length ring allgather over raw bytes.
//
// @param sizes         Byte count contributed by each worker (already gathered).
// @param data          This worker's own contribution.
// @param erased_result Output buffer sized to the sum of `sizes`; on return it
//                      holds every worker's data laid out in rank order.
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
                                    common::Span<std::int8_t const> data,
                                    common::Span<std::int8_t> erased_result) {
  auto world = comm.World();
  auto rank = comm.Rank();
  auto prev = BootstrapPrev(rank, comm.World());
  auto next = BootstrapNext(rank, comm.World());
  auto prev_ch = comm.Chan(prev);
  auto next_ch = comm.Chan(next);

  // get worker offset
  // offset[r] is the byte offset of worker r's segment; offset[world] is the total.
  std::vector<std::int64_t> offset(world + 1, 0);
  std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
  CHECK_EQ(*offset.cbegin(), 0);

  // copy data
  // Place our own contribution into its slot of the result buffer first.
  auto current = erased_result.subspan(offset[rank], data.size_bytes());
  auto erased_data = EraseType(data);
  std::copy_n(erased_data.data(), erased_data.size(), current.data());

  for (std::int32_t r = 0; r < world; ++r) {
    // Forward the segment obtained in the previous round to the ring-next worker.
    auto send_rank = (rank + world - r) % world;
    auto send_off = offset[send_rank];
    auto send_size = sizes[send_rank];
    auto send_seg = erased_result.subspan(send_off, send_size);
    next_ch->SendAll(send_seg);

    // Receive the next segment from the ring-previous worker.
    auto recv_rank = (rank + world - r - 1) % world;
    auto recv_off = offset[recv_rank];
    auto recv_size = sizes[recv_rank];
    auto recv_seg = erased_result.subspan(recv_off, recv_size);
    prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
    // Wait for this round's transfers before the segments are reused.
    auto rc = prev_ch->Block();
    if (!rc.OK()) {
      return rc;
    }
  }
  return comm.Block();
}
} // namespace xgboost::collective::cpu_impl

View File

@@ -0,0 +1,72 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <memory> // for shared_ptr
#include <numeric> // for accumulate
#include <type_traits> // for remove_cv_t
#include <vector> // for vector
#include "comm.h" // for Comm, Channel, EraseType
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
namespace cpu_impl {
/**
 * @param worker_off Segment offset. For example, if the rank 2 worker specifies worker_off
* = 1, then it owns the third segment.
*/
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
std::size_t segment_size, std::int32_t worker_off,
std::shared_ptr<Channel> prev_ch,
std::shared_ptr<Channel> next_ch);
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int8_t const> data,
common::Span<std::int8_t> erased_result);
} // namespace cpu_impl
/**
 * @brief Ring-based allgather for typed data.
 *
 * @param comm Communicator for the ring.
 * @param data Buffer holding all segments; this worker's segment must be filled
 *             in at its rank position before the call.
 * @param size Number of elements in each worker's segment.
 */
template <typename T>
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
  auto n_bytes = sizeof(T) * size;  // per-worker segment size in bytes
  auto erased = EraseType(data);
  auto rank = comm.Rank();
  auto prev = BootstrapPrev(rank, comm.World());
  auto next = BootstrapNext(rank, comm.World());
  auto prev_ch = comm.Chan(prev);
  auto next_ch = comm.Chan(next);

  auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes, 0, prev_ch, next_ch);
  if (!rc.OK()) {
    return rc;
  }
  // Drain any operations still queued on the event loop.
  return comm.Block();
}
/**
 * @brief Ring-based allgather for variable-sized typed data.
 *
 * @param comm  Communicator for the ring.
 * @param data  This worker's contribution.
 * @param p_out On success, holds the concatenation of every worker's data in rank
 *              order.
 */
template <typename T>
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<T> data,
                                    std::vector<std::remove_cv_t<T>>* p_out) {
  auto world = comm.World();
  auto rank = comm.Rank();

  // First exchange the per-worker byte counts so every worker can compute offsets.
  std::vector<std::int64_t> sizes(world, 0);
  sizes[rank] = data.size_bytes();
  auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}, 1);
  if (!rc.OK()) {
    return rc;
  }

  // Use a 64-bit accumulator: a plain `0` literal would make std::accumulate sum
  // the int64 sizes into an `int`, truncating/overflowing past 2GiB of payload.
  auto n_total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), std::int64_t{0});
  // Use the cv-stripped element type so this also compiles when T is const
  // (matching the declared type of *p_out).
  std::vector<std::remove_cv_t<T>>& result = *p_out;
  result.resize(n_total_bytes / sizeof(T));
  auto h_result = common::Span{result.data(), result.size()};
  auto erased_result = EraseType(h_result);
  auto erased_data = EraseType(data);
  return cpu_impl::RingAllgatherV(comm, sizes, erased_data, erased_result);
}
} // namespace xgboost::collective

View File

@@ -0,0 +1,90 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "allreduce.h"
#include <algorithm> // for min
#include <cstddef> // for size_t
#include <cstdint> // for int32_t, int8_t
#include <vector> // for vector
#include "../data/array_interface.h" // for Type, DispatchDType
#include "allgather.h" // for RingAllgather
#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective::cpu_impl {
// Ring scatter-reduce: after world-1 rounds each worker holds the fully reduced
// values for its own segment of `data`.
//
// @param n_bytes_in_seg Segment size in bytes (the tail segment may be shorter).
// @param op             Reduction applied as op(received_bytes, accumulate_into).
template <typename T>
Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
                              std::size_t n_bytes_in_seg, Func const& op) {
  auto rank = comm.Rank();
  auto world = comm.World();

  auto dst_rank = BootstrapNext(rank, world);
  auto src_rank = BootstrapPrev(rank, world);
  auto next_ch = comm.Chan(dst_rank);
  auto prev_ch = comm.Chan(src_rank);

  // Scratch buffer for the incoming segment of each round.
  std::vector<std::int8_t> buffer(n_bytes_in_seg, 0);
  auto s_buf = common::Span{buffer.data(), buffer.size()};

  for (std::int32_t r = 0; r < world - 1; ++r) {
    // send to ring next
    auto send_off = ((rank + world - r) % world) * n_bytes_in_seg;
    // Clamp offset/size: the tail segment can be shorter than n_bytes_in_seg.
    send_off = std::min(send_off, data.size_bytes());
    auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg);
    auto send_seg = data.subspan(send_off, seg_nbytes);
    next_ch->SendAll(send_seg);

    // receive from ring prev
    auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg;
    recv_off = std::min(recv_off, data.size_bytes());
    seg_nbytes = std::min(data.size_bytes() - recv_off, n_bytes_in_seg);
    CHECK_EQ(seg_nbytes % sizeof(T), 0);
    auto recv_seg = data.subspan(recv_off, seg_nbytes);
    auto seg = s_buf.subspan(0, recv_seg.size());
    prev_ch->RecvAll(seg);
    // Wait for this round's transfers to finish before reducing.
    auto rc = prev_ch->Block();
    if (!rc.OK()) {
      return rc;
    }

    // accumulate to recv_seg
    CHECK_EQ(seg.size(), recv_seg.size());
    op(seg, recv_seg);
  }

  return Success();
}
// Type-dispatched ring allreduce: a scatter-reduce phase followed by an allgather
// of the reduced segments.
Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const& op,
                     ArrayInterfaceHandler::Type type) {
  return DispatchDType(type, [&](auto t) {
    using T = decltype(t);
    // Divide the data into segments according to the number of workers.
    auto n_bytes_elem = sizeof(T);
    CHECK_EQ(data.size_bytes() % n_bytes_elem, 0);
    auto n = data.size_bytes() / n_bytes_elem;
    auto world = comm.World();
    auto n_bytes_in_seg = common::DivRoundUp(n, world) * sizeof(T);

    // Phase 1: each worker ends up owning the reduced values of one segment.
    auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op);
    if (!rc.OK()) {
      return rc;
    }

    auto prev = BootstrapPrev(comm.Rank(), comm.World());
    auto next = BootstrapNext(comm.Rank(), comm.World());
    auto prev_ch = comm.Chan(prev);
    auto next_ch = comm.Chan(next);

    // Phase 2: circulate the reduced segments so every worker has the full result.
    // worker_off = 1 because after scatter-reduce rank r owns segment (r + 1) % world.
    rc = RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
    if (!rc.OK()) {
      return rc;
    }
    return comm.Block();
  });
}
} // namespace xgboost::collective::cpu_impl

View File

@@ -0,0 +1,39 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int8_t
#include <functional> // for function
#include <type_traits> // for is_invocable_v
#include "../data/array_interface.h" // for ArrayInterfaceHandler
#include "comm.h" // for Comm, RestoreType
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
namespace cpu_impl {
using Func =
std::function<void(common::Span<std::int8_t const> lhs, common::Span<std::int8_t> out)>;
Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const& op,
ArrayInterfaceHandler::Type type);
} // namespace cpu_impl
/**
 * @brief Type-safe allreduce wrapper: erases the element type, forwards to the
 *        byte-level ring implementation, and restores the type inside the
 *        reduction callback.
 *
 * @param comm  Communicator.
 * @param data  In/out buffer; on success it holds the reduced values.
 * @param redop Reduction invoked as redop(lhs, accumulate_into).
 */
template <typename T, typename Fn>
std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>, Result> Allreduce(
    Comm const& comm, common::Span<T> data, Fn redop) {
  auto erased = EraseType(data);
  auto type = ToDType<T>::kType;

  // Capture only `redop`: `type` is not used inside the lambda, and capturing it
  // triggers -Wunused-lambda-capture.
  auto erased_fn = [redop](common::Span<std::int8_t const> lhs, common::Span<std::int8_t> out) {
    CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
    auto lhs_t = RestoreType<T const>(lhs);
    auto rhs_t = RestoreType<T>(out);
    redop(lhs_t, rhs_t);
  };

  return cpu_impl::RingAllreduce(comm, erased, erased_fn, type);
}
} // namespace xgboost::collective

View File

@@ -0,0 +1,84 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "broadcast.h"
#include <cmath> // for ceil, log2
#include <cstdint> // for int32_t, int8_t
#include <utility> // for move
#include "../common/bitfield.h" // for TrailingZeroBits, RBitField32
#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective::cpu_impl {
namespace {
// Compute the (shifted) rank of this node's parent in the binomial broadcast
// tree: the parent is obtained by clearing the lowest set bit of the shifted
// rank within the first depth+1 bits.
std::int32_t ShiftedParentRank(std::int32_t shifted_rank, std::int32_t depth) {
  std::uint32_t mask{std::uint32_t{0} - 1};  // 0xff... (all bits set)
  RBitField32 maskbits{common::Span<std::uint32_t>{&mask, 1}};
  RBitField32 rankbits{
      common::Span<std::uint32_t>{reinterpret_cast<std::uint32_t*>(&shifted_rank), 1}};
  // prepare for counting trailing zeros.
  // Copy the low depth+1 bits of the rank into the mask; the higher bits stay set
  // so the mask is never all-zero.
  for (std::int32_t i = 0; i < depth + 1; ++i) {
    if (rankbits.Check(i)) {
      maskbits.Set(i);
    } else {
      maskbits.Clear(i);
    }
  }
  CHECK_NE(mask, 0);
  // k is the position of the lowest set bit of shifted_rank within depth+1 bits.
  auto k = TrailingZeroBits(mask);
  auto shifted_parent = shifted_rank - (1 << k);
  return shifted_parent;
}
// Rotate the rank space so that the broadcast root maps onto rank 0.
std::int32_t ShiftLeft(std::int32_t rank, std::int32_t world, std::int32_t root) {
  return (rank - root + world) % world;
}
// Undo ShiftLeft: map a shifted rank back to the original rank space.
std::int32_t ShiftRight(std::int32_t rank, std::int32_t world, std::int32_t root) {
  return (rank + root) % world;
}
} // namespace
Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t root) {
  // Binomial tree broadcast
  // * Wiki
  // https://en.wikipedia.org/wiki/Broadcast_(parallel_pattern)#Binomial_Tree_Broadcast
  // * Impl
  // https://people.mpi-inf.mpg.de/~mehlhorn/ftp/NewToolbox/collective.pdf
  auto rank = comm.Rank();
  auto world = comm.World();

  // shift root to rank 0
  auto shifted_rank = ShiftLeft(rank, world, root);
  // Number of doubling rounds needed to reach all workers, minus one.
  std::int32_t depth = std::ceil(std::log2(static_cast<double>(world))) - 1;

  if (shifted_rank != 0) {  // not root
    // Non-root workers first receive the payload from their tree parent.
    auto parent = ShiftRight(ShiftedParentRank(shifted_rank, depth), world, root);
    comm.Chan(parent)->RecvAll(data);
    auto rc = comm.Chan(parent)->Block();
    if (!rc.OK()) {
      return Fail("broadcast failed.", std::move(rc));
    }
  }

  // Then forward the payload to each child, largest subtree first.
  for (std::int32_t i = depth; i >= 0; --i) {
    CHECK_GE((i + 1), 0);  // weird clang-tidy error that i might be negative
    if (shifted_rank % (1 << (i + 1)) == 0 && shifted_rank + (1 << i) < world) {
      auto sft_peer = shifted_rank + (1 << i);
      auto peer = ShiftRight(sft_peer, world, root);
      CHECK_NE(peer, root);
      comm.Chan(peer)->SendAll(data);
    }
  }

  return comm.Block();
}
} // namespace xgboost::collective::cpu_impl

View File

@@ -0,0 +1,26 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int32_t, int8_t
#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
namespace cpu_impl {
Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t root);
}
/**
* @brief binomial tree broadcast is used on CPU with the default implementation.
*/
/**
 * @brief Broadcast typed data from `root` to all workers.
 *
 * @param comm Communicator.
 * @param data In/out buffer; on non-root workers it is overwritten with the
 *             root's payload.
 * @param root Rank of the broadcasting worker.
 */
template <typename T>
[[nodiscard]] Result Broadcast(Comm const& comm, common::Span<T> data, std::int32_t root) {
  // Reuse the shared EraseType helper (comm.h) instead of hand-rolling the
  // reinterpret_cast, keeping byte erasure in one place, consistent with the
  // allgather/allreduce wrappers.
  return cpu_impl::Broadcast(comm, EraseType(data), root);
}
} // namespace xgboost::collective

304
src/collective/comm.cc Normal file
View File

@@ -0,0 +1,304 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "comm.h"
#include <algorithm> // for copy
#include <chrono> // for seconds
#include <memory> // for shared_ptr
#include <string> // for string
#include <utility> // for move, forward
#include "allgather.h"
#include "protocol.h" // for kMagic
#include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE
#include "xgboost/collective/socket.h" // for TCPSocket
#include "xgboost/json.h" // for Json, Object
#include "xgboost/string_view.h" // for StringView
namespace xgboost::collective {
// Construct a communicator that knows how to reach the tracker; rank defaults to
// 0 and the tracker-side rank is left unset (-1) until bootstrap completes.
Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
           std::int32_t retry, std::string task_id)
    : timeout_{timeout},
      retry_{retry},
      tracker_{host, port, -1},
      task_id_{std::move(task_id)},
      loop_{std::make_shared<Loop>(timeout)} {}
// Open a blocking connection to the tracker and perform the initial handshake
// (magic verification followed by sending this worker's info).
//
// @param out On success, holds the connected tracker socket.
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
                          std::string const& task_id, TCPSocket* out, std::int32_t rank,
                          std::int32_t world) {
  // get information from tracker
  CHECK(!info.host.empty());
  auto rc = Connect(info.host, info.port, retry, timeout, out);
  if (!rc.OK()) {
    return Fail("Failed to connect to the tracker.", std::move(rc));
  }

  TCPSocket& tracker = *out;
  // Chain the handshake steps; the first failing step short-circuits the rest.
  return std::move(rc)
         << [&] { return tracker.NonBlocking(false); }
         << [&] { return tracker.RecvTimeout(timeout); }
         << [&] { return proto::Magic{}.Verify(&tracker); }
         << [&] { return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id); };
}
// Connect to this communicator's tracker using its stored connection settings.
[[nodiscard]] Result Comm::ConnectTracker(TCPSocket* out) const {
  return ConnectTrackerImpl(this->TrackerInfo(), this->Timeout(), this->retry_, this->task_id_, out,
                            this->Rank(), this->World());
}
// Build the fully-connected worker topology: connect the ring neighbours, gather
// every worker's host name and port over the ring, then open a direct socket to
// every other worker.
//
// @param listener    Listening socket for incoming peer connections.
// @param lport       Port this worker is listening on.
// @param ninfo       Peer info of the ring-next worker (from the tracker).
// @param out_workers On success, one socket per rank (empty at our own rank).
[[nodiscard]] Result ConnectWorkers(Comm const& comm, TCPSocket* listener, std::int32_t lport,
                                    proto::PeerInfo ninfo, std::chrono::seconds timeout,
                                    std::int32_t retry,
                                    std::vector<std::shared_ptr<TCPSocket>>* out_workers) {
  auto next = std::make_shared<TCPSocket>();
  auto prev = std::make_shared<TCPSocket>();

  // Connect to ring-next, accept from ring-prev, then make both non-blocking.
  auto rc = Success() << [&] {
    auto rc = Connect(ninfo.host, ninfo.port, retry, timeout, next.get());
    if (!rc.OK()) {
      return Fail("Bootstrap failed to connect to ring next.", std::move(rc));
    }
    return rc;
  } << [&] {
    return next->NonBlocking(true);
  } << [&] {
    SockAddrV4 addr;
    return listener->Accept(prev.get(), &addr);
  } << [&] { return prev->NonBlocking(true); };
  if (!rc.OK()) {
    return rc;
  }

  // exchange host name and port
  // Fixed-size slot per worker; zero-filled so names stay NUL-terminated.
  std::vector<std::int8_t> buffer(HOST_NAME_MAX * comm.World(), 0);
  auto s_buffer = common::Span{buffer.data(), buffer.size()};
  auto next_host = s_buffer.subspan(HOST_NAME_MAX * comm.Rank(), HOST_NAME_MAX);
  if (next_host.size() < ninfo.host.size()) {
    return Fail("Got an invalid host name.");
  }
  // The slot at our own rank holds the *next* worker's host; the BootstrapNext
  // shift when assembling `peers` below compensates for this.
  std::copy(ninfo.host.cbegin(), ninfo.host.cend(), next_host.begin());

  auto prev_ch = std::make_shared<Channel>(comm, prev);
  auto next_ch = std::make_shared<Channel>(comm, next);

  // Drain the queued operations on both ring channels.
  auto block = [&] {
    for (auto ch : {prev_ch, next_ch}) {
      auto rc = ch->Block();
      if (!rc.OK()) {
        return rc;
      }
    }
    return Success();
  };

  // Ring-allgather the host names.
  rc = std::move(rc) << [&] {
    return cpu_impl::RingAllgather(comm, s_buffer, HOST_NAME_MAX, 0, prev_ch, next_ch);
  } << [&] { return block(); };
  if (!rc.OK()) {
    return Fail("Failed to get host names from peers.", std::move(rc));
  }

  // Ring-allgather the listening ports (our slot again holds next's port).
  std::vector<std::int32_t> peers_port(comm.World(), -1);
  peers_port[comm.Rank()] = ninfo.port;
  rc = std::move(rc) << [&] {
    auto s_ports = common::Span{reinterpret_cast<std::int8_t*>(peers_port.data()),
                                peers_port.size() * sizeof(ninfo.port)};
    return cpu_impl::RingAllgather(comm, s_ports, sizeof(ninfo.port), 0, prev_ch, next_ch);
  } << [&] { return block(); };
  if (!rc.OK()) {
    return Fail("Failed to get the port from peers.", std::move(rc));
  }

  // Assemble peer info, shifting slot r ("next of r") back to its own rank.
  std::vector<proto::PeerInfo> peers(comm.World());
  for (auto r = 0; r < comm.World(); ++r) {
    auto nhost = s_buffer.subspan(HOST_NAME_MAX * r, HOST_NAME_MAX);
    auto nport = peers_port[r];
    auto nrank = BootstrapNext(r, comm.World());
    peers[nrank] = {std::string{reinterpret_cast<char const*>(nhost.data())}, nport, nrank};
  }
  CHECK_EQ(peers[comm.Rank()].port, lport);
  for (auto const& p : peers) {
    CHECK_NE(p.port, -1);
  }

  std::vector<std::shared_ptr<TCPSocket>>& workers = *out_workers;
  workers.resize(comm.World());

  // Actively connect to every higher rank, then announce our own rank to it.
  for (std::int32_t r = (comm.Rank() + 1); r < comm.World(); ++r) {
    auto const& peer = peers[r];
    std::shared_ptr<TCPSocket> worker{TCPSocket::CreatePtr(comm.Domain())};
    rc = std::move(rc)
         << [&] { return Connect(peer.host, peer.port, retry, timeout, worker.get()); }
         << [&] { return worker->RecvTimeout(timeout); };
    if (!rc.OK()) {
      return rc;
    }
    auto rank = comm.Rank();
    auto n_bytes = worker->SendAll(&rank, sizeof(comm.Rank()));
    if (n_bytes != sizeof(comm.Rank())) {
      return Fail("Failed to send rank.");
    }
    workers[r] = std::move(worker);
  }

  // Accept a connection from every lower rank; the peer tells us its rank.
  for (std::int32_t r = 0; r < comm.Rank(); ++r) {
    SockAddrV4 addr;
    auto peer = std::shared_ptr<TCPSocket>(TCPSocket::CreatePtr(comm.Domain()));
    rc = std::move(rc) << [&] { return listener->Accept(peer.get(), &addr); }
                       << [&] { return peer->RecvTimeout(timeout); };
    if (!rc.OK()) {
      return rc;
    }
    std::int32_t rank{-1};
    auto n_bytes = peer->RecvAll(&rank, sizeof(rank));
    if (n_bytes != sizeof(comm.Rank())) {
      return Fail("Failed to recv rank.");
    }
    workers[rank] = std::move(peer);
  }

  // Every rank other than our own must now have a live socket.
  for (std::int32_t r = 0; r < comm.World(); ++r) {
    if (r == comm.Rank()) {
      continue;
    }
    CHECK(workers[r]);
  }

  return Success();
}
// Construct and immediately bootstrap the communicator; aborts on bootstrap
// failure since a half-initialized comm is unusable.
RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
                     std::int32_t retry, std::string task_id)
    // `host` is a const reference: std::move on it would be a copy anyway and only
    // obscures intent (clang-tidy: performance-move-const-arg), so pass it directly.
    : Comm{host, port, timeout, retry, std::move(task_id)} {
  auto rc = this->Bootstrap(timeout_, retry_, task_id_);
  CHECK(rc.OK()) << rc.Report();
}
// Bootstrap the communicator: handshake with the tracker, spawn the error-notice
// listener thread, learn world size and ring neighbours, then connect all workers.
[[nodiscard]] Result RabitComm::Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
                                          std::string task_id) {
  TCPSocket tracker;
  std::int32_t world{-1};
  auto rc = ConnectTrackerImpl(this->TrackerInfo(), timeout, retry, task_id, &tracker, this->Rank(),
                               world);
  if (!rc.OK()) {
    return Fail("Bootstrap failed.", std::move(rc));
  }

  this->domain_ = tracker.Domain();

  // Start command
  // Listener for incoming peer connections during topology construction.
  TCPSocket listener = TCPSocket::Create(tracker.Domain());
  std::int32_t lport = listener.BindHost();
  listener.Listen();

  // create worker for listening to error notice.
  auto domain = tracker.Domain();
  std::shared_ptr<TCPSocket> error_sock{TCPSocket::CreatePtr(domain)};
  auto eport = error_sock->BindHost();
  error_sock->Listen();
  error_worker_ = std::thread{[this, error_sock = std::move(error_sock)] {
    auto conn = error_sock->Accept();
    // On Windows accept returns an invalid socket after network is shutdown.
    if (conn.IsClosed()) {
      return;
    }
    LOG(WARNING) << "Another worker is running into error.";
    std::string scmd;
    conn.Recv(&scmd);
    auto jcmd = Json::Load(scmd);

    auto rc = this->Shutdown();
    if (!rc.OK()) {
      LOG(WARNING) << "Fail to shutdown worker:" << rc.Report();
    }
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
    exit(-1);
#else
    LOG(FATAL) << rc.Report();
#endif
  }};
  // Detached: the thread blocks in Accept until an error notice arrives.
  error_worker_.detach();

  // Exchange the start command; the tracker replies with the world size.
  proto::Start start;
  rc = std::move(rc) << [&] { return start.WorkerSend(lport, &tracker, eport); }
                     << [&] { return start.WorkerRecv(&tracker, &world); };
  if (!rc.OK()) {
    return rc;
  }
  this->world_ = world;

  // get ring neighbors
  std::string snext;
  tracker.Recv(&snext);
  auto jnext = Json::Load(StringView{snext});
  proto::PeerInfo ninfo{jnext};

  // get the rank of this worker
  // The tracker sends our ring-next peer, so our own rank is its predecessor.
  this->rank_ = BootstrapPrev(ninfo.rank, world);
  this->tracker_.rank = rank_;

  std::vector<std::shared_ptr<TCPSocket>> workers;
  rc = ConnectWorkers(*this, &listener, lport, ninfo, timeout, retry, &workers);
  if (!rc.OK()) {
    return rc;
  }

  CHECK(this->channels_.empty());
  // Wrap each peer socket (empty at our own rank) into a channel.
  for (auto& w : workers) {
    if (w) {
      w->SetNoDelay();
      rc = w->NonBlocking(true);
    }
    if (!rc.OK()) {
      return rc;
    }
    this->channels_.emplace_back(std::make_shared<Channel>(*this, w));
  }
  return rc;
}
// Destructor is noexcept(false) because shutdown performs network I/O; failures
// here are downgraded to a warning rather than propagated.
RabitComm::~RabitComm() noexcept(false) {
  // Nothing to tear down for a single-worker (non-distributed) comm.
  if (!IsDistributed()) {
    return;
  }
  auto rc = this->Shutdown();
  if (!rc.OK()) {
    LOG(WARNING) << rc.Report();
  }
}
// Announce shutdown to the tracker: reconnect, drain pending operations, then
// send the shutdown command.
[[nodiscard]] Result RabitComm::Shutdown() {
  TCPSocket tracker;
  return Success() << [&] {
    return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World());
  } << [&] {
    // Make sure all queued collective operations finish before leaving.
    return this->Block();
  } << [&] {
    Json jcmd{Object{}};
    jcmd["cmd"] = Integer{static_cast<std::int32_t>(proto::CMD::kShutdown)};
    auto scmd = Json::Dump(jcmd);
    auto n_bytes = tracker.Send(scmd);
    if (n_bytes != scmd.size()) {
      // Fixed typo in the error message (was "Faled to send cmd.").
      return Fail("Failed to send cmd.");
    }
    return Success();
  };
}
// Forward a log message to the tracker over a fresh connection.
[[nodiscard]] Result RabitComm::LogTracker(std::string msg) const {
  TCPSocket out;
  proto::Print print;
  return Success() << [&] { return this->ConnectTracker(&out); }
                   << [&] { return print.WorkerSend(&out, msg); };
}
// Notify the tracker that this worker hit an error so it can alert the peers.
[[nodiscard]] Result RabitComm::SignalError(Result const& res) {
  TCPSocket out;
  return Success() << [&] { return this->ConnectTracker(&out); }
                   << [&] { return proto::ErrorCMD{}.WorkerSend(&out, res); };
}
} // namespace xgboost::collective

156
src/collective/comm.h Normal file
View File

@@ -0,0 +1,156 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <chrono> // for seconds
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <memory> // for shared_ptr
#include <string> // for string
#include <thread> // for thread
#include <type_traits> // for remove_const_t
#include <utility> // for move
#include <vector> // for vector
#include "loop.h" // for Loop
#include "protocol.h" // for PeerInfo
#include "xgboost/collective/result.h" // for Result
#include "xgboost/collective/socket.h" // for TCPSocket
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
// Default network timeout in seconds.
inline constexpr std::int32_t DefaultTimeoutSec() { return 300; }  // 5min
// Default number of connection retries.
inline constexpr std::int32_t DefaultRetry() { return 3; }
// indexing into the ring
// Successor of rank `r` on the bootstrap ring, wrapping around at `world`.
inline std::int32_t BootstrapNext(std::int32_t r, std::int32_t world) {
  auto const successor = (r + world + 1) % world;
  return successor;
}
// Predecessor of rank `r` on the bootstrap ring, wrapping around at `world`.
inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) {
  auto const predecessor = (r + world - 1) % world;
  return predecessor;
}
class Channel;
/**
* @brief Base communicator storing info about the tracker and other communicators.
*/
class Comm {
 protected:
  std::int32_t world_{1};  // number of workers; 1 when not distributed
  std::int32_t rank_{0};   // this worker's rank
  std::chrono::seconds timeout_{DefaultTimeoutSec()};
  std::int32_t retry_{DefaultRetry()};
  proto::PeerInfo tracker_;             // how to reach the tracker
  SockDomain domain_{SockDomain::kV4};  // socket address family
  std::thread error_worker_;            // background error-notice listener
  std::string task_id_;
  std::vector<std::shared_ptr<Channel>> channels_;  // one channel per rank

  std::shared_ptr<Loop> loop_{new Loop{std::chrono::seconds{
      DefaultTimeoutSec()}}};  // fixme: require federated comm to have a timeout

 public:
  Comm() = default;
  Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, std::int32_t retry,
       std::string task_id);
  // Derived destructors (e.g. RabitComm) may throw during shutdown.
  virtual ~Comm() noexcept(false) {}  // NOLINT

  // Non-copyable and non-movable: channels and the error thread refer back to us.
  Comm(Comm const& that) = delete;
  Comm& operator=(Comm const& that) = delete;
  Comm(Comm&& that) = delete;
  Comm& operator=(Comm&& that) = delete;

  [[nodiscard]] auto TrackerInfo() const { return tracker_; }
  [[nodiscard]] Result ConnectTracker(TCPSocket* out) const;
  [[nodiscard]] auto Domain() const { return domain_; }
  [[nodiscard]] auto Timeout() const { return timeout_; }

  [[nodiscard]] auto Rank() const { return rank_; }
  [[nodiscard]] auto World() const { return world_; }
  [[nodiscard]] bool IsDistributed() const { return World() > 1; }
  // Queue an I/O operation on the shared event loop.
  void Submit(Loop::Op op) const { loop_->Submit(op); }
  // Wait for all queued operations on the event loop to finish.
  [[nodiscard]] Result Block() const { return loop_->Block(); }

  // Channel to the worker at `rank`.
  [[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const {
    return channels_.at(rank);
  }
  [[nodiscard]] virtual bool IsFederated() const = 0;
  [[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
  [[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
};
// TCP/socket-based communicator following the RABIT bootstrap protocol.
class RabitComm : public Comm {
  // Connect to the tracker and build the worker topology; called from the ctor.
  [[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
                                 std::string task_id);
  // Announce shutdown to the tracker.
  [[nodiscard]] Result Shutdown();

 public:
  // bootstrapping construction.
  RabitComm() = default;
  // ctor for testing where environment is known.
  RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
            std::int32_t retry, std::string task_id);
  // May throw if shutdown fails; see the noexcept(false) base destructor.
  ~RabitComm() noexcept(false) override;

  [[nodiscard]] bool IsFederated() const override { return false; }
  [[nodiscard]] Result LogTracker(std::string msg) const override;
  [[nodiscard]] Result SignalError(Result const&) override;
};
/**
* @brief Communication channel between workers.
*/
class Channel {
  std::shared_ptr<TCPSocket> sock_{nullptr};
  Result rc_;
  Comm const& comm_;  // non-owning; the comm must outlive its channels

 public:
  explicit Channel(Comm const& comm, std::shared_ptr<TCPSocket> sock)
      : sock_{std::move(sock)}, comm_{comm} {}

  // Queue a full send on the comm's event loop (asynchronous; see Block).
  void SendAll(std::int8_t const* ptr, std::size_t n) {
    // const_cast: Loop::Op stores a mutable pointer, but a write op does not
    // modify the buffer.
    Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast<std::int8_t*>(ptr), n, sock_.get(), 0};
    CHECK(sock_.get());
    comm_.Submit(std::move(op));
  }
  void SendAll(common::Span<std::int8_t const> data) {
    this->SendAll(data.data(), data.size_bytes());
  }

  // Queue a full receive on the comm's event loop (asynchronous; see Block).
  void RecvAll(std::int8_t* ptr, std::size_t n) {
    Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0};
    CHECK(sock_.get());
    comm_.Submit(std::move(op));
  }
  void RecvAll(common::Span<std::int8_t> data) { this->RecvAll(data.data(), data.size_bytes()); }

  [[nodiscard]] auto Socket() const { return sock_; }
  // Wait until all operations queued on the comm's loop have completed.
  [[nodiscard]] Result Block() { return comm_.Block(); }
};
// Reduction operations supported by the collective interface.
enum class Op { kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 };
// View `data` as raw bytes.  Constness is propagated: a span of const T becomes
// a span of const bytes.
template <typename T, typename U = std::conditional_t<std::is_const_v<T>,
                                                      std::add_const_t<std::int8_t>, std::int8_t>>
common::Span<U> EraseType(common::Span<T> data) {
  auto n_total_bytes = data.size_bytes();
  auto erased = common::Span{reinterpret_cast<std::add_pointer_t<U>>(data.data()), n_total_bytes};
  return erased;
}
// Inverse of EraseType: reinterpret a byte span as a span of T.  The byte count
// is assumed to be a multiple of sizeof(T) (any remainder is truncated).
template <typename T, typename U>
common::Span<T> RestoreType(common::Span<U> data) {
  // Only byte spans can be restored.
  static_assert(std::is_same_v<std::remove_const_t<U>, std::int8_t>);
  auto n_total_bytes = data.size_bytes();
  auto restored = common::Span{reinterpret_cast<T*>(data.data()), n_total_bytes / sizeof(T)};
  return restored;
}
} // namespace xgboost::collective

View File

@@ -57,9 +57,7 @@ namespace collective {
* - federated_client_key: Client key file path. Only needed for the SSL mode.
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
*/
inline void Init(Json const& config) {
Communicator::Init(config);
}
inline void Init(Json const &config) { Communicator::Init(config); }
/*!
* \brief Finalize the collective communicator.
@@ -141,17 +139,89 @@ inline void Broadcast(std::string *sendrecv_data, int root) {
}
}
/**
 * @brief Gathers a single value from all processes and distributes the result to
 *        all processes.
 *
 * @param input The single value.  Assumed trivially copyable — TODO confirm.
 * @return Vector with one element per worker, indexed by rank.
 */
template <typename T>
inline std::vector<T> Allgather(T const &input) {
  // View the value as raw bytes for the string-based communicator API.
  std::string_view str_input{reinterpret_cast<char const *>(&input), sizeof(T)};
  auto const output = Communicator::Get()->AllGather(str_input);
  CHECK_EQ(output.size() % sizeof(T), 0);
  std::vector<T> result(output.size() / sizeof(T));
  std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
  return result;
}
/**
* @brief Gathers data from all processes and distributes it to all processes.
*
* This assumes all ranks have the same size, and input data has been sliced into the
* corresponding position.
* This assumes all ranks have the same size.
*
* @param send_receive_buffer Buffer storing the data.
* @param size Size of the data in bytes.
* @param input Buffer storing the data.
*/
inline void Allgather(void *send_receive_buffer, std::size_t size) {
Communicator::Get()->AllGather(send_receive_buffer, size);
// Gathers equally-sized vectors from all processes and concatenates them in rank
// order.  Assumes every rank passes the same number of elements.
template <typename T>
inline std::vector<T> Allgather(std::vector<T> const &input) {
  // With equal sizes everywhere, an empty input means nothing is gathered.
  if (input.empty()) {
    return input;
  }
  std::string_view str_input{reinterpret_cast<char const *>(input.data()),
                             input.size() * sizeof(T)};
  auto const output = Communicator::Get()->AllGather(str_input);
  CHECK_EQ(output.size() % sizeof(T), 0);
  std::vector<T> result(output.size() / sizeof(T));
  std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
  return result;
}
/**
 * @brief Gathers variable-length data from all processes and distributes it to all processes.
 * @param input Buffer storing the data.
 * @return Concatenation of every worker's input, in rank order.
 */
template <typename T>
inline std::vector<T> AllgatherV(std::vector<T> const &input) {
  std::string_view str_input{reinterpret_cast<char const *>(input.data()),
                             input.size() * sizeof(T)};
  auto const output = Communicator::Get()->AllGatherV(str_input);
  CHECK_EQ(output.size() % sizeof(T), 0);
  std::vector<T> result(output.size() / sizeof(T));
  // Skip the copy entirely when nothing was gathered.
  if (!output.empty()) {
    std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
  }
  return result;
}
/**
 * @brief Gathers variable-length strings from all processes and distributes them to all processes.
 * @param input Variable-length list of variable-length strings.
 * @return All strings from all workers, in rank order.
 */
inline std::vector<std::string> AllgatherStrings(std::vector<std::string> const &input) {
  std::size_t total_size{0};
  for (auto const &s : input) {
    total_size += s.length() + 1;  // +1 for null-terminators
  }
  // Flatten into a single NUL-delimited buffer so AllGatherV can transport it.
  // NOTE(review): a string containing an embedded '\0' would be split apart by
  // the parsing loop below — confirm inputs never contain NUL bytes.
  std::string flat_string;
  flat_string.reserve(total_size);
  for (auto const &s : input) {
    flat_string.append(s);
    flat_string.push_back('\0');  // Append a null-terminator after each string
  }

  auto const output = Communicator::Get()->AllGatherV(flat_string);

  std::vector<std::string> result;
  std::size_t start_index = 0;
  // Iterate through the output, find each null-terminated substring.
  for (std::size_t i = 0; i < output.size(); i++) {
    if (output[i] == '\0') {
      // Construct a std::string from the char* substring
      result.emplace_back(&output[start_index]);
      // Move to the next substring
      start_index = i + 1;
    }
  }
  return result;
}
/*!
@@ -226,7 +296,7 @@ inline void Allreduce(double *send_receive_buffer, size_t count) {
}
template <typename T>
struct AllgatherVResult {
struct SpecialAllgatherVResult {
std::vector<std::size_t> offsets;
std::vector<std::size_t> sizes;
std::vector<T> result;
@@ -241,14 +311,10 @@ struct AllgatherVResult {
* @param sizes Sizes of each input.
*/
template <typename T>
inline AllgatherVResult<T> AllgatherV(std::vector<T> const &inputs,
std::vector<std::size_t> const &sizes) {
auto num_inputs = sizes.size();
inline SpecialAllgatherVResult<T> SpecialAllgatherV(std::vector<T> const &inputs,
std::vector<std::size_t> const &sizes) {
// Gather the sizes across all workers.
std::vector<std::size_t> all_sizes(num_inputs * GetWorldSize());
std::copy_n(sizes.cbegin(), sizes.size(), all_sizes.begin() + num_inputs * GetRank());
collective::Allgather(all_sizes.data(), all_sizes.size() * sizeof(std::size_t));
auto const all_sizes = Allgather(sizes);
// Calculate input offsets (std::exclusive_scan).
std::vector<std::size_t> offsets(all_sizes.size());
@@ -257,11 +323,7 @@ inline AllgatherVResult<T> AllgatherV(std::vector<T> const &inputs,
}
// Gather all the inputs.
auto total_input_size = offsets.back() + all_sizes.back();
std::vector<T> all_inputs(total_input_size);
std::copy_n(inputs.cbegin(), inputs.size(), all_inputs.begin() + offsets[num_inputs * GetRank()]);
// We cannot use allgather here, since each worker might have a different size.
Allreduce<Operation::kMax>(all_inputs.data(), all_inputs.size());
auto const all_inputs = AllgatherV(inputs);
return {offsets, all_sizes, all_inputs};
}

View File

@@ -11,9 +11,7 @@
#include "../../plugin/federated/federated_communicator.h"
#endif
namespace xgboost {
namespace collective {
namespace xgboost::collective {
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
thread_local CommunicatorType Communicator::type_{};
@@ -57,6 +55,4 @@ void Communicator::Finalize() {
communicator_.reset(new NoOpCommunicator());
}
#endif
} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective

View File

@@ -125,13 +125,17 @@ class Communicator {
/**
* @brief Gathers data from all processes and distributes it to all processes.
*
* This assumes all ranks have the same size, and input data has been sliced into the
* corresponding position.
* This assumes all ranks have the same size.
*
* @param send_receive_buffer Buffer storing the data.
* @param size Size of the data in bytes.
* @param input Buffer storing the data.
*/
virtual void AllGather(void *send_receive_buffer, std::size_t size) = 0;
virtual std::string AllGather(std::string_view input) = 0;
/**
* @brief Gathers variable-length data from all processes and distributes it to all processes.
* @param input Buffer storing the data.
*/
virtual std::string AllGatherV(std::string_view input) = 0;
/**
* @brief Combines values from all processes and distributes the result back to all processes.

View File

@@ -40,12 +40,10 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
host_buffer_.resize(send_size * world_size_);
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + rank_ * send_size, send_buffer, send_size,
cudaMemcpyDefault));
Allgather(host_buffer_.data(), host_buffer_.size());
dh::safe_cuda(
cudaMemcpy(receive_buffer, host_buffer_.data(), host_buffer_.size(), cudaMemcpyDefault));
host_buffer_.resize(send_size);
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_buffer, send_size, cudaMemcpyDefault));
auto const output = Allgather(host_buffer_);
dh::safe_cuda(cudaMemcpy(receive_buffer, output.data(), output.size(), cudaMemcpyDefault));
}
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,

View File

@@ -60,11 +60,16 @@ class InMemoryCommunicator : public Communicator {
bool IsDistributed() const override { return true; }
bool IsFederated() const override { return false; }
void AllGather(void* in_out, std::size_t size) override {
std::string AllGather(std::string_view input) override {
std::string output;
handler_.Allgather(static_cast<const char*>(in_out), size, &output, sequence_number_++,
GetRank());
output.copy(static_cast<char*>(in_out), size);
handler_.Allgather(input.data(), input.size(), &output, sequence_number_++, GetRank());
return output;
}
std::string AllGatherV(std::string_view input) override {
std::string output;
handler_.AllgatherV(input.data(), input.size(), &output, sequence_number_++, GetRank());
return output;
}
void AllReduce(void* in_out, std::size_t size, DataType data_type, Operation operation) override {

View File

@@ -16,23 +16,49 @@ class AllgatherFunctor {
public:
std::string const name{"Allgather"};
AllgatherFunctor(int world_size, int rank) : world_size_{world_size}, rank_{rank} {}
AllgatherFunctor(std::size_t world_size, std::size_t rank)
: world_size_{world_size}, rank_{rank} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
if (buffer->empty()) {
// Copy the input if this is the first request.
buffer->assign(input, bytes);
} else {
// Splice the input into the common buffer.
auto const per_rank = bytes / world_size_;
auto const index = rank_ * per_rank;
buffer->replace(index, per_rank, input + index, per_rank);
// Resize the buffer if this is the first request.
buffer->resize(bytes * world_size_);
}
// Splice the input into the common buffer.
buffer->replace(rank_ * bytes, bytes, input, bytes);
}
private:
std::size_t world_size_;
std::size_t rank_;
};
/**
* @brief Functor for variable-length allgather.
*/
class AllgatherVFunctor {
public:
std::string const name{"AllgatherV"};
AllgatherVFunctor(std::size_t world_size, std::size_t rank,
std::map<std::size_t, std::string_view>* data)
: world_size_{world_size}, rank_{rank}, data_{data} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
data_->emplace(rank_, std::string_view{input, bytes});
if (data_->size() == world_size_) {
for (auto const& kv : *data_) {
buffer->append(kv.second);
}
data_->clear();
}
}
private:
int world_size_;
int rank_;
std::size_t world_size_;
std::size_t rank_;
std::map<std::size_t, std::string_view>* data_;
};
/**
@@ -154,7 +180,7 @@ class BroadcastFunctor {
public:
std::string const name{"Broadcast"};
BroadcastFunctor(int rank, int root) : rank_{rank}, root_{root} {}
BroadcastFunctor(std::size_t rank, std::size_t root) : rank_{rank}, root_{root} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
if (rank_ == root_) {
@@ -164,11 +190,11 @@ class BroadcastFunctor {
}
private:
int rank_;
int root_;
std::size_t rank_;
std::size_t root_;
};
void InMemoryHandler::Init(int world_size, int) {
void InMemoryHandler::Init(std::size_t world_size, std::size_t) {
CHECK(world_size_ < world_size) << "In memory handler already initialized.";
std::unique_lock<std::mutex> lock(mutex_);
@@ -178,7 +204,7 @@ void InMemoryHandler::Init(int world_size, int) {
cv_.notify_all();
}
void InMemoryHandler::Shutdown(uint64_t sequence_number, int) {
void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) {
CHECK(world_size_ > 0) << "In memory handler already shutdown.";
std::unique_lock<std::mutex> lock(mutex_);
@@ -194,24 +220,30 @@ void InMemoryHandler::Shutdown(uint64_t sequence_number, int) {
}
void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, int rank) {
std::size_t sequence_number, std::size_t rank) {
Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank});
}
void InMemoryHandler::AllgatherV(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank) {
Handle(input, bytes, output, sequence_number, rank, AllgatherVFunctor{world_size_, rank, &aux_});
}
void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, int rank, DataType data_type,
std::size_t sequence_number, std::size_t rank, DataType data_type,
Operation op) {
Handle(input, bytes, output, sequence_number, rank, AllreduceFunctor{data_type, op});
}
void InMemoryHandler::Broadcast(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, int rank, int root) {
std::size_t sequence_number, std::size_t rank, std::size_t root) {
Handle(input, bytes, output, sequence_number, rank, BroadcastFunctor{rank, root});
}
template <class HandlerFunctor>
void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, int rank, HandlerFunctor const& functor) {
std::size_t sequence_number, std::size_t rank,
HandlerFunctor const& functor) {
// Pass through if there is only 1 client.
if (world_size_ == 1) {
if (input != output->data()) {

View File

@@ -3,6 +3,7 @@
*/
#pragma once
#include <condition_variable>
#include <map>
#include <string>
#include "communicator.h"
@@ -31,7 +32,7 @@ class InMemoryHandler {
*
* This is used when the handler only needs to be initialized once with a known world size.
*/
explicit InMemoryHandler(int worldSize) : world_size_{worldSize} {}
explicit InMemoryHandler(std::size_t worldSize) : world_size_{worldSize} {}
/**
* @brief Initialize the handler with the world size and rank.
@@ -41,7 +42,7 @@ class InMemoryHandler {
* This is used when multiple objects/threads are accessing the same handler and need to
* initialize it collectively.
*/
void Init(int world_size, int rank);
void Init(std::size_t world_size, std::size_t rank);
/**
* @brief Shut down the handler.
@@ -51,7 +52,7 @@ class InMemoryHandler {
* This is used when multiple objects/threads are accessing the same handler and need to
* shut it down collectively.
*/
void Shutdown(uint64_t sequence_number, int rank);
void Shutdown(uint64_t sequence_number, std::size_t rank);
/**
* @brief Perform allgather.
@@ -62,7 +63,18 @@ class InMemoryHandler {
* @param rank Index of the worker.
*/
void Allgather(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, int rank);
std::size_t sequence_number, std::size_t rank);
/**
* @brief Perform variable-length allgather.
* @param input The input buffer.
* @param bytes Number of bytes in the input buffer.
* @param output The output buffer.
* @param sequence_number Call sequence number.
* @param rank Index of the worker.
*/
void AllgatherV(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank);
/**
* @brief Perform allreduce.
@@ -75,7 +87,7 @@ class InMemoryHandler {
* @param op The reduce operation.
*/
void Allreduce(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, int rank, DataType data_type, Operation op);
std::size_t sequence_number, std::size_t rank, DataType data_type, Operation op);
/**
* @brief Perform broadcast.
@@ -87,7 +99,7 @@ class InMemoryHandler {
* @param root Index of the worker to broadcast from.
*/
void Broadcast(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, int rank, int root);
std::size_t sequence_number, std::size_t rank, std::size_t root);
private:
/**
@@ -102,15 +114,16 @@ class InMemoryHandler {
*/
template <class HandlerFunctor>
void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number,
int rank, HandlerFunctor const& functor);
std::size_t rank, HandlerFunctor const& functor);
int world_size_{}; /// Number of workers.
int received_{}; /// Number of calls received with the current sequence.
int sent_{}; /// Number of calls completed with the current sequence.
std::string buffer_{}; /// A shared common buffer.
uint64_t sequence_number_{}; /// Call sequence number.
mutable std::mutex mutex_; /// Lock.
mutable std::condition_variable cv_; /// Conditional variable to wait on.
std::size_t world_size_{}; /// Number of workers.
std::size_t received_{}; /// Number of calls received with the current sequence.
std::size_t sent_{}; /// Number of calls completed with the current sequence.
std::string buffer_{}; /// A shared common buffer.
std::map<std::size_t, std::string_view> aux_{}; /// A shared auxiliary map.
uint64_t sequence_number_{}; /// Call sequence number.
mutable std::mutex mutex_; /// Lock.
mutable std::condition_variable cv_; /// Conditional variable to wait on.
};
} // namespace collective

167
src/collective/loop.cc Normal file
View File

@@ -0,0 +1,167 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "loop.h"
#include <queue> // for queue
#include "rabit/internal/socket.h" // for PollHelper
#include "xgboost/collective/socket.h" // for FailWithCode
#include "xgboost/logging.h" // for CHECK
namespace xgboost::collective {
/**
 * @brief Drain the pending socket-operation queue, polling until every op completes.
 *
 * Each round watches all queued sockets, polls for readiness, performs partial
 * reads/writes, and re-queues any op that has not yet transferred all its bytes.
 *
 * @return Success when the queue is fully drained, otherwise the first error.
 */
Result Loop::EmptyQueue() {
  timer_.Start(__func__);
  auto error = [this] {
    // Mark the loop as stopped so no further ops are processed after a failure.
    this->stop_ = true;
    timer_.Stop(__func__);
  };

  while (!queue_.empty() && !stop_) {
    std::queue<Op> qcopy;
    rabit::utils::PollHelper poll;
    // watch all ops
    while (!queue_.empty()) {
      auto op = queue_.front();
      queue_.pop();
      switch (op.code) {
        case Op::kRead: {
          poll.WatchRead(*op.sock);
          break;
        }
        case Op::kWrite: {
          poll.WatchWrite(*op.sock);
          break;
        }
        default: {
          error();
          return Fail("Invalid socket operation.");
        }
      }
      qcopy.push(op);
    }

    // poll, work on fds that are ready.
    timer_.Start("poll");
    auto rc = poll.Poll(timeout_);
    timer_.Stop("poll");
    if (!rc.OK()) {
      error();
      return rc;
    }
    // we wouldn't be here if the queue is empty.
    CHECK(!qcopy.empty());

    while (!qcopy.empty() && !stop_) {
      auto op = qcopy.front();
      qcopy.pop();

      std::int32_t n_bytes_done{0};
      CHECK(op.sock->NonBlocking());
      switch (op.code) {
        case Op::kRead: {
          if (poll.CheckRead(*op.sock)) {
            n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
          }
          break;
        }
        case Op::kWrite: {
          if (poll.CheckWrite(*op.sock)) {
            n_bytes_done = op.sock->Send(op.ptr + op.off, op.n - op.off);
          }
          break;
        }
        default: {
          error();
          return Fail("Invalid socket operation.");
        }
      }

      if (n_bytes_done == -1) {
        if (!system::LastErrorWouldBlock()) {
          stop_ = true;
          auto rc = system::FailWithCode("Invalid socket output.");
          error();
          return rc;
        }
        // EWOULDBLOCK: no progress this round. Clamp to zero instead of adding -1 to
        // the offset (the original decremented op.off here, corrupting the transfer).
        n_bytes_done = 0;
      }
      op.off += n_bytes_done;
      CHECK_LE(op.off, op.n);

      if (op.off != op.n) {
        // not yet finished, push back to queue for next round.
        queue_.push(op);
      }
    }
  }
  timer_.Stop(__func__);
  return Success();
}
// Consumer thread body: wait for submitted ops (or a stop request), then drain the
// queue with EmptyQueue(). Runs until stop_ is set or an error occurs.
void Loop::Process() {
  // consumer
  while (true) {
    std::unique_lock lock{mu_};
    // Sleep until work arrives or shutdown is requested.
    cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; });
    if (stop_) {
      break;
    }
    // NOTE(review): try_lock on a mutex this thread already owns is technically UB;
    // it is used here purely as a sanity check that the lock is still held.
    CHECK(!mu_.try_lock());
    // The lock stays held while draining, so producers block until the round finishes.
    this->rc_ = this->EmptyQueue();
    if (!rc_.OK()) {
      stop_ = true;
      cv_.notify_one();  // wake any thread waiting in Block() so it can observe the failure
      break;
    }

    CHECK(queue_.empty());
    CHECK(!mu_.try_lock());
    cv_.notify_one();  // signal Block() that the queue has been drained
  }

  if (rc_.OK()) {
    // A clean exit implies every submitted op was completed.
    CHECK(queue_.empty());
  }
}
/**
 * @brief Request the worker loop to stop, wait for it to finish, and surface any
 * exception captured on the worker thread.
 */
Result Loop::Stop() {
  {
    // Flip the stop flag under the lock, releasing it before we wait in Block().
    std::lock_guard guard{mu_};
    stop_ = true;
  }

  CHECK_EQ(this->Block().OK(), this->rc_.OK());

  if (curr_exce_) {
    // Re-raise the worker-thread exception on the caller's thread.
    std::rethrow_exception(curr_exce_);
  }
  return Success();
}
/**
 * @brief Construct the event loop and launch the worker thread.
 *
 * Any exception escaping Process() is captured into curr_exce_ so it can be rethrown
 * from Stop() on the caller's thread, and the loop is marked stopped.
 *
 * @param timeout Socket poll timeout used by EmptyQueue().
 */
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
  timer_.Init(__func__);
  worker_ = std::thread{[this] {
    try {
      this->Process();
    } catch (...) {
      // The original had two byte-identical handlers (std::exception const& with the
      // exception object unused, then catch-all); a single catch-all is equivalent.
      std::lock_guard<std::mutex> guard{mu_};
      if (!curr_exce_) {
        curr_exce_ = std::current_exception();
        rc_ = Fail("Exception was thrown");
      }
      stop_ = true;
      cv_.notify_all();
    }
  }};
}
} // namespace xgboost::collective

83
src/collective/loop.h Normal file
View File

@@ -0,0 +1,83 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <chrono> // for seconds
#include <condition_variable> // for condition_variable
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int32_t
#include <exception> // for exception_ptr
#include <mutex> // for unique_lock, mutex
#include <queue> // for queue
#include <thread> // for thread
#include <utility> // for move
#include "../common/timer.h" // for Monitor
#include "xgboost/collective/result.h" // for Result
#include "xgboost/collective/socket.h" // for TCPSocket
namespace xgboost::collective {
/**
 * @brief A single-threaded event loop that serializes non-blocking socket reads and writes.
 *
 * Producers enqueue `Op`s via Submit(); a dedicated worker thread polls the sockets and
 * performs the transfers. Block() waits until the queue is drained and returns the result.
 */
class Loop {
 public:
  /** @brief One pending read or write on a socket, tracking partial progress. */
  struct Op {
    enum Code : std::int8_t { kRead = 0, kWrite = 1 } code;  // operation type
    std::int32_t rank{-1};      // rank associated with this op (presumably the peer — confirm at call sites)
    std::int8_t* ptr{nullptr};  // buffer to read into / write from (non-owning)
    std::size_t n{0};           // total number of bytes to transfer
    TCPSocket* sock{nullptr};   // socket to operate on (non-owning)
    std::size_t off{0};         // bytes already transferred

    Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
        : code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
    Op(Op const&) = default;
    Op& operator=(Op const&) = default;
    Op(Op&&) = default;
    Op& operator=(Op&&) = default;
  };

 private:
  std::thread worker_;            // consumer thread running Process()
  std::condition_variable cv_;    // signals submissions, completion, and stop
  std::mutex mu_;                 // guards queue_, rc_, stop_, curr_exce_
  std::queue<Op> queue_;          // pending operations
  std::chrono::seconds timeout_;  // poll timeout for socket readiness
  Result rc_;                     // result of the last queue drain
  bool stop_{false};              // set to terminate the worker thread
  std::exception_ptr curr_exce_{nullptr};  // exception captured on the worker thread, if any
  common::Monitor timer_;

  // Drain all queued ops, polling sockets until each completes or an error occurs.
  Result EmptyQueue();
  // Worker thread body: waits for submissions and drains the queue.
  void Process();

 public:
  // Signal the worker to stop, wait for it, and rethrow any captured exception.
  Result Stop();

  /** @brief Enqueue an operation for the worker thread (producer side). */
  void Submit(Op op) {
    // producer
    std::unique_lock lock{mu_};
    queue_.push(op);
    lock.unlock();
    cv_.notify_one();
  }

  /** @brief Wait until the queue is fully processed, then return (move out) the result. */
  [[nodiscard]] Result Block() {
    {
      std::unique_lock lock{mu_};
      cv_.notify_all();  // wake the worker in case it is idle
    }
    std::unique_lock lock{mu_};
    cv_.wait(lock, [this] { return this->queue_.empty() || stop_; });
    return std::move(rc_);
  }

  explicit Loop(std::chrono::seconds timeout);

  // NOTE(review): a potentially-throwing destructor is unusual — Stop() may rethrow
  // an exception captured on the worker thread.
  ~Loop() noexcept(false) {
    this->Stop();
    if (worker_.joinable()) {
      worker_.join();
    }
  }
};
} // namespace xgboost::collective

View File

@@ -17,10 +17,11 @@ class NoOpCommunicator : public Communicator {
NoOpCommunicator() : Communicator(1, 0) {}
bool IsDistributed() const override { return false; }
bool IsFederated() const override { return false; }
void AllGather(void *, std::size_t) override {}
std::string AllGather(std::string_view) override { return {}; }
std::string AllGatherV(std::string_view) override { return {}; }
void AllReduce(void *, std::size_t, DataType, Operation) override {}
void Broadcast(void *, std::size_t, int) override {}
std::string GetProcessorName() override { return ""; }
std::string GetProcessorName() override { return {}; }
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
protected:

214
src/collective/protocol.h Normal file
View File

@@ -0,0 +1,214 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int32_t
#include <string> // for string
#include <utility> // for move
#include "xgboost/collective/result.h" // for Result
#include "xgboost/collective/socket.h" // for TCPSocket
#include "xgboost/json.h" // for Json
namespace xgboost::collective::proto {
/**
 * @brief Host/port/rank triple identifying one worker in the communication group.
 */
struct PeerInfo {
  std::string host;       // host address of the worker
  std::int32_t port{-1};  // listening port; -1 means unset
  std::int32_t rank{-1};  // worker rank; -1 means unset

  PeerInfo() = default;
  PeerInfo(std::string host, std::int32_t port, std::int32_t rank)
      : host{std::move(host)}, port{port}, rank{rank} {}
  /** @brief Decode from a JSON object carrying "host", "port", and "rank" fields. */
  explicit PeerInfo(Json const& peer)
      : host{get<String>(peer["host"])},
        port{static_cast<std::int32_t>(get<Integer const>(peer["port"]))},
        rank{static_cast<std::int32_t>(get<Integer const>(peer["rank"]))} {}
  /** @brief Encode into a JSON object; the inverse of the JSON constructor. */
  [[nodiscard]] Json ToJson() const {
    Json info{Object{}};
    info["rank"] = rank;
    info["host"] = String{host};
    info["port"] = Integer{port};
    return info;
  }
  /** @brief Render as "host:port". */
  [[nodiscard]] auto HostPort() const { return host + ":" + std::to_string(this->port); }
};
/**
 * @brief Handshake token exchanged between peers to validate a new connection.
 */
struct Magic {
  static constexpr std::int32_t kMagic = 0xff99;

  /**
   * @brief Send our magic number and require the peer to echo the same value back.
   * @param p_sock Connected socket to the peer.
   */
  [[nodiscard]] Result Verify(xgboost::collective::TCPSocket* p_sock) {
    std::int32_t token{kMagic};
    // Push our token first; a short write means the handshake failed.
    if (p_sock->SendAll(&token, sizeof(token)) != sizeof(token)) {
      return Fail("Failed to verify.");
    }
    token = 0;
    // Then expect the peer's token; a short read likewise fails the handshake.
    if (p_sock->RecvAll(&token, sizeof(token)) != sizeof(token)) {
      return Fail("Failed to verify.");
    }
    if (token != kMagic) {
      return xgboost::collective::Fail("Invalid verification number.");
    }
    return Success();
  }
};
// Commands a worker sends to the tracker (serialized as the "cmd" JSON field).
enum class CMD : std::int32_t {
  kInvalid = 0,
  kStart = 1,     // worker announces its ports and joins the group (see Start)
  kShutdown = 2,  // worker leaves gracefully (see ShutdownCMD)
  kError = 3,     // worker reports a failure with message and code (see ErrorCMD)
  kPrint = 4,     // worker asks the tracker to print a message (see Print)
};
/**
 * @brief Initial connection message exchanged when a worker first contacts the tracker.
 */
struct Connect {
  /**
   * @brief Worker side: send the worker's identity to the tracker.
   * @param tracker Socket connected to the tracker.
   * @param world   Expected world size.
   * @param rank    This worker's rank.
   * @param task_id Identifier of the running task.
   */
  [[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::int32_t world, std::int32_t rank,
                                  std::string task_id) const {
    Json jinit{Object{}};
    jinit["world_size"] = Integer{world};
    jinit["rank"] = Integer{rank};
    jinit["task_id"] = String{task_id};
    std::string msg;
    Json::Dump(jinit, &msg);
    auto n_bytes = tracker->Send(msg);
    if (n_bytes != msg.size()) {
      return Fail("Failed to send init command from worker.");
    }
    return Success();
  }
  /**
   * @brief Tracker side: receive and decode the worker's identity.
   * @param sock    Socket connected to the worker.
   * @param world   Output: world size reported by the worker.
   * @param rank    Output: rank reported by the worker.
   * @param task_id Output: task identifier reported by the worker.
   */
  [[nodiscard]] Result TrackerRecv(TCPSocket* sock, std::int32_t* world, std::int32_t* rank,
                                   std::string* task_id) const {
    std::string init;
    auto n_bytes = sock->Recv(&init);
    // Guard against a closed or failed connection: the original passed an empty buffer
    // straight into the JSON parser.
    if (n_bytes == 0 || init.empty()) {
      return Fail("Failed to recv init command from worker.");
    }
    auto jinit = Json::Load(StringView{init});
    *world = get<Integer const>(jinit["world_size"]);
    *rank = get<Integer const>(jinit["rank"]);
    *task_id = get<String const>(jinit["task_id"]);
    return Success();
  }
};
/**
 * @brief The "start" handshake: a worker announces its listening ports and receives the
 * world size from the tracker.
 */
class Start {
 private:
  // Tracker side: reply to the worker with the world size.
  [[nodiscard]] Result TrackerSend(std::int32_t world, TCPSocket* worker) const {
    Json jcmd{Object{}};
    jcmd["world_size"] = Integer{world};
    auto scmd = Json::Dump(jcmd);
    auto n_bytes = worker->Send(scmd);
    if (n_bytes != scmd.size()) {
      return Fail("Failed to send init command from tracker.");
    }
    return Success();
  }

 public:
  /**
   * @brief Worker side: send the data port and error port to the tracker.
   * @param lport   Listening (data) port of this worker.
   * @param tracker Socket connected to the tracker.
   * @param eport   Error-notification port of this worker.
   */
  [[nodiscard]] Result WorkerSend(std::int32_t lport, TCPSocket* tracker,
                                  std::int32_t eport) const {
    Json jcmd{Object{}};
    jcmd["cmd"] = Integer{static_cast<std::int32_t>(CMD::kStart)};
    jcmd["port"] = Integer{lport};
    jcmd["error_port"] = Integer{eport};
    auto scmd = Json::Dump(jcmd);
    auto n_bytes = tracker->Send(scmd);
    if (n_bytes != scmd.size()) {
      return Fail("Failed to send init command from worker.");
    }
    return Success();
  }
  /**
   * @brief Worker side: receive and validate the world size from the tracker.
   * @param tracker Socket connected to the tracker.
   * @param p_world Output: the world size (must be positive).
   */
  [[nodiscard]] Result WorkerRecv(TCPSocket* tracker, std::int32_t* p_world) const {
    std::string scmd;
    auto n_bytes = tracker->Recv(&scmd);
    if (n_bytes <= 0) {
      return Fail("Failed to recv init command from tracker.");
    }
    auto jcmd = Json::Load(scmd);
    auto world = get<Integer const>(jcmd["world_size"]);
    if (world <= 0) {
      return Fail("Invalid world size.");
    }
    *p_world = world;
    return Success();
  }
  /**
   * @brief Tracker side: decode the worker's ports and reply with the world size.
   * @param jcmd       Decoded start command from the worker.
   * @param recv_world In/out: previously received world size; must be -1 (unset) on entry.
   * @param world      World size to send back.
   * @param p_port     Output: the worker's data port (must be positive).
   * @param p_sock     Socket connected to the worker, used for the reply.
   * @param eport      Output: the worker's error-notification port.
   */
  [[nodiscard]] Result TrackerHandle(Json jcmd, std::int32_t* recv_world, std::int32_t world,
                                     std::int32_t* p_port, TCPSocket* p_sock,
                                     std::int32_t* eport) const {
    *p_port = get<Integer const>(jcmd["port"]);
    if (*p_port <= 0) {
      return Fail("Invalid port.");
    }
    // A start must only be handled once per worker per session.
    if (*recv_world != -1) {
      return Fail("Invalid initialization sequence.");
    }
    *recv_world = world;
    *eport = get<Integer const>(jcmd["error_port"]);
    return TrackerSend(world, p_sock);
  }
};
/**
 * @brief Protocol message asking the tracker to print a line on behalf of a worker.
 */
struct Print {
  /** @brief Worker side: serialize the print request and send it to the tracker. */
  [[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::string msg) const {
    Json request{Object{}};
    request["cmd"] = Integer{static_cast<std::int32_t>(CMD::kPrint)};
    request["msg"] = String{std::move(msg)};
    auto const payload = Json::Dump(request);
    if (tracker->Send(payload) != payload.size()) {
      return Fail("Failed to send print command from worker.");
    }
    return Success();
  }
  /** @brief Tracker side: validate and extract the message to print. */
  [[nodiscard]] Result TrackerHandle(Json jcmd, std::string* p_msg) const {
    if (!IsA<String>(jcmd["msg"])) {
      return Fail("Invalid print command.");
    }
    *p_msg = get<String const>(jcmd["msg"]);
    return Success();
  }
};
/**
 * @brief Protocol message reporting a worker-side failure to the tracker.
 */
struct ErrorCMD {
  /** @brief Worker side: serialize the failed Result (message + code) and send it. */
  [[nodiscard]] Result WorkerSend(TCPSocket* tracker, Result const& res) const {
    auto msg = res.Report();
    auto code = res.Code().value();

    Json request{Object{}};
    request["msg"] = String{std::move(msg)};
    request["code"] = Integer{code};
    request["cmd"] = Integer{static_cast<std::int32_t>(CMD::kError)};
    auto const payload = Json::Dump(request);
    if (tracker->Send(payload) != payload.size()) {
      return Fail("Failed to send error command from worker.");
    }
    return Success();
  }
  /** @brief Tracker side: validate and extract the error message and code. */
  [[nodiscard]] Result TrackerHandle(Json jcmd, std::string* p_msg, int* p_code) const {
    if (!IsA<String>(jcmd["msg"]) || !IsA<Integer>(jcmd["code"])) {
      return Fail("Invalid error command.");
    }
    *p_msg = get<String const>(jcmd["msg"]);
    *p_code = get<Integer const>(jcmd["code"]);
    return Success();
  }
};
/**
 * @brief Protocol message a worker sends to signal a graceful shutdown.
 */
struct ShutdownCMD {
  /** @brief Send the shutdown command to the peer (the tracker). */
  [[nodiscard]] Result Send(TCPSocket* peer) const {
    Json request{Object{}};
    request["cmd"] = Integer{static_cast<std::int32_t>(proto::CMD::kShutdown)};
    auto const payload = Json::Dump(request);
    if (peer->Send(payload) != payload.size()) {
      return Fail("Failed to send shutdown command from worker.");
    }
    return Success();
  }
};
} // namespace xgboost::collective::proto

View File

@@ -7,6 +7,7 @@
#include <string>
#include <vector>
#include "communicator-inl.h"
#include "communicator.h"
#include "xgboost/json.h"
@@ -55,10 +56,29 @@ class RabitCommunicator : public Communicator {
bool IsFederated() const override { return false; }
void AllGather(void *send_receive_buffer, std::size_t size) override {
auto const per_rank = size / GetWorldSize();
std::string AllGather(std::string_view input) override {
auto const per_rank = input.size();
auto const total_size = per_rank * GetWorldSize();
auto const index = per_rank * GetRank();
rabit::Allgather(static_cast<char *>(send_receive_buffer), size, index, per_rank, per_rank);
std::string result(total_size, '\0');
result.replace(index, per_rank, input);
rabit::Allgather(result.data(), total_size, index, per_rank, per_rank);
return result;
}
std::string AllGatherV(std::string_view input) override {
auto const size_node_slice = input.size();
auto const all_sizes = collective::Allgather(size_node_slice);
auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul);
auto const begin_index =
std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul);
auto const size_prev_slice =
GetRank() == 0 ? all_sizes[GetWorldSize() - 1] : all_sizes[GetRank() - 1];
std::string result(total_size, '\0');
result.replace(begin_index, size_node_slice, input);
rabit::Allgather(result.data(), total_size, begin_index, size_node_slice, size_prev_slice);
return result;
}
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,

View File

@@ -3,6 +3,7 @@
*/
#include "xgboost/collective/socket.h"
#include <array> // for array
#include <cstddef> // std::size_t
#include <cstdint> // std::int32_t
#include <cstring> // std::memcpy, std::memset
@@ -92,13 +93,18 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
conn = TCPSocket::Create(addr.Domain());
CHECK_EQ(static_cast<std::int32_t>(conn.Domain()), static_cast<std::int32_t>(addr.Domain()));
conn.SetNonBlock(true);
auto non_blocking = conn.NonBlocking();
auto rc = conn.NonBlocking(true);
if (!rc.OK()) {
return Fail("Failed to set socket option.", std::move(rc));
}
Result last_error;
auto log_failure = [&host, &last_error](Result err, char const *file, std::int32_t line) {
auto log_failure = [&host, &last_error, port](Result err, char const *file, std::int32_t line) {
last_error = std::move(err);
LOG(WARNING) << std::filesystem::path{file}.filename().string() << "(" << line
<< "): Failed to connect to:" << host << " Error:" << last_error.Report();
<< "): Failed to connect to:" << host << ":" << port
<< " Error:" << last_error.Report();
};
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
@@ -112,39 +118,42 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
}
auto rc = connect(conn.Handle(), addr_handle, addr_len);
if (rc != 0) {
auto errcode = system::LastError();
if (!system::ErrorWouldBlock(errcode)) {
log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}),
__FILE__, __LINE__);
continue;
}
rabit::utils::PollHelper poll;
poll.WatchWrite(conn);
auto result = poll.Poll(timeout);
if (!result.OK()) {
log_failure(std::move(result), __FILE__, __LINE__);
continue;
}
if (!poll.CheckWrite(conn)) {
log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}),
__FILE__, __LINE__);
continue;
}
result = conn.GetSockError();
if (!result.OK()) {
log_failure(std::move(result), __FILE__, __LINE__);
continue;
}
conn.SetNonBlock(false);
return Success();
} else {
conn.SetNonBlock(false);
return Success();
if (rc == 0) {
return conn.NonBlocking(non_blocking);
}
auto errcode = system::LastError();
if (!system::ErrorWouldBlock(errcode)) {
log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}),
__FILE__, __LINE__);
continue;
}
rabit::utils::PollHelper poll;
poll.WatchWrite(conn);
auto result = poll.Poll(timeout);
if (!result.OK()) {
// poll would fail if there's a socket error, we log the root cause instead of the
// poll failure.
auto sockerr = conn.GetSockError();
if (!sockerr.OK()) {
result = std::move(sockerr);
}
log_failure(std::move(result), __FILE__, __LINE__);
continue;
}
if (!poll.CheckWrite(conn)) {
log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}), __FILE__,
__LINE__);
continue;
}
result = conn.GetSockError();
if (!result.OK()) {
log_failure(std::move(result), __FILE__, __LINE__);
continue;
}
return conn.NonBlocking(non_blocking);
}
std::stringstream ss;
@@ -152,4 +161,13 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
conn.Close();
return Fail(ss.str(), std::move(last_error));
}
/**
 * @brief Obtain the name of the local host.
 * @param p_out Output: the host name.
 * @return Success, or a system error when gethostname fails.
 */
[[nodiscard]] Result GetHostName(std::string *p_out) {
  std::array<char, HOST_NAME_MAX> buf;
  if (gethostname(buf.data(), buf.size()) != 0) {
    return system::FailWithCode("Failed to get host name.");
  }
  // POSIX does not guarantee null-termination when the name is truncated, so
  // terminate explicitly before constructing the string.
  buf.back() = '\0';
  *p_out = buf.data();
  return Success();
}
} // namespace xgboost::collective

296
src/collective/tracker.cc Normal file
View File

@@ -0,0 +1,296 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#if defined(__unix__) || defined(__APPLE__)
#include <netdb.h> // gethostbyname
#include <sys/socket.h> // socket, AF_INET6, AF_INET, connect, getsockname
#endif // defined(__unix__) || defined(__APPLE__)
#if !defined(NOMINMAX) && defined(_WIN32)
#define NOMINMAX
#endif // !defined(NOMINMAX)
#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
#endif // defined(_WIN32)
#include <algorithm> // for sort
#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <string> // for string
#include <utility> // for move, forward
#include "../common/json_utils.h"
#include "comm.h"
#include "protocol.h" // for kMagic, PeerInfo
#include "tracker.h"
#include "xgboost/collective/result.h" // for Result, Fail, Success
#include "xgboost/collective/socket.h" // for GetHostName, FailWithCode, MakeSockAddress, ...
#include "xgboost/json.h"
namespace xgboost::collective {
/**
 * @brief Construct tracker configuration from a JSON config.
 *
 * Required key: "n_workers". Optional keys: "port" (defaults to 0 — presumably letting
 * the OS choose during bind; confirm against RabitTracker) and "timeout" in seconds
 * (defaults to collective::DefaultTimeoutSec()).
 */
Tracker::Tracker(Json const& config)
    : n_workers_{static_cast<std::int32_t>(
          RequiredArg<Integer const>(config, "n_workers", __func__))},
      port_{static_cast<std::int32_t>(OptionalArg<Integer const>(config, "port", Integer::Int{0}))},
      timeout_{std::chrono::seconds{OptionalArg<Integer const>(
          config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
/**
 * @brief Accept a newly connected worker: verify the handshake, read its identity,
 * then decode and dispatch its first command.
 *
 * Any failure is stored in rc_ for the caller to inspect via Status(); the
 * constructor never throws for protocol errors.
 */
RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr)
    : sock_{std::move(sock)} {
  auto host = addr.Addr();

  std::int32_t rank{0};
  // Handshake: exchange the magic number, then receive world size/rank/task id.
  rc_ = Success()
        << [&] { return proto::Magic{}.Verify(&sock_); }
        << [&] { return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_); };
  if (!rc_.OK()) {
    return;
  }

  // The first command after the handshake determines what this worker wants.
  // NOTE(review): the Recv result is not checked; an empty payload would be fed to the
  // JSON parser — confirm upstream guarantees.
  std::string cmd;
  sock_.Recv(&cmd);
  auto jcmd = Json::Load(StringView{cmd});
  cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));

  std::int32_t port{0};
  if (cmd_ == proto::CMD::kStart) {
    proto::Start start;
    rc_ = start.TrackerHandle(jcmd, &world_, world, &port, &sock_, &eport_);
  } else if (cmd_ == proto::CMD::kPrint) {
    proto::Print print;
    rc_ = print.TrackerHandle(jcmd, &msg_);
  } else if (cmd_ == proto::CMD::kError) {
    proto::ErrorCMD error;
    rc_ = error.TrackerHandle(jcmd, &msg_, &code_);
  }
  if (!rc_.OK()) {
    return;
  }
  // Record where this worker can be reached by its peers.
  info_ = proto::PeerInfo{host, port, rank};
}
/**
 * @brief Construct the tracker: resolve the local host address, bind, and listen.
 *
 * The "host" config key overrides the auto-detected address.
 */
RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
  std::string self;
  auto rc = collective::GetHostAddress(&self);
  // Fail fast if we cannot determine our own address; the original silently
  // discarded this Result and could proceed with an empty host.
  CHECK(rc.OK()) << rc.Report();
  auto host = OptionalArg<String>(config, "host", self);

  listener_ = TCPSocket::Create(SockDomain::kV4);
  rc = listener_.Bind(host, &this->port_);
  CHECK(rc.OK()) << rc.Report();
  listener_.Listen();
}
/**
 * @brief Send each worker the peer info of its neighbour so workers can connect
 * to each other.
 *
 * Workers are first sorted (WorkerCmp) to obtain a stable ordering; one thread per
 * worker then pushes the JSON-encoded PeerInfo of the worker chosen by
 * BootstrapNext (presumably its successor in a ring — confirm in comm.h).
 *
 * @param p_workers All connected worker proxies; sorted in place.
 */
Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
  auto& workers = *p_workers;
  std::sort(workers.begin(), workers.end(), WorkerCmp{});

  std::vector<std::thread> bootstrap_threads;
  for (std::int32_t r = 0; r < n_workers_; ++r) {
    auto& worker = workers[r];
    auto next = BootstrapNext(r, n_workers_);
    auto const& next_w = workers[next];
    // Send concurrently: each thread delivers the neighbour's info to worker r.
    bootstrap_threads.emplace_back([next, &worker, &next_w] {
      auto jnext = proto::PeerInfo{next_w.Host(), next_w.Port(), next}.ToJson();
      std::string str;
      Json::Dump(jnext, &str);
      worker.Send(StringView{str});
    });
  }

  for (auto& t : bootstrap_threads) {
    t.join();
  }

  // Remember each worker's error port so the tracker can signal failures later.
  for (auto const& w : workers) {
    worker_error_handles_.emplace_back(w.Host(), w.ErrorPort());
  }
  return Success();
}
[[nodiscard]] std::future<Result> RabitTracker::Run() {
  // A state machine to keep track of consistency between worker connections.
  struct State {
    // Expected number of workers in the group.
    std::int32_t const n_workers;
    // Number of workers that have sent a shutdown command.
    std::int32_t n_shutdown{0};
    // Whether an error was reported and workers are restarting.
    bool during_restart{false};
    // Workers that have connected but are not yet bootstrapped.
    std::vector<WorkerProxy> pending;

    explicit State(std::int32_t world) : n_workers{world} {}
    State(State const& that) = delete;
    State& operator=(State&& that) = delete;

    // A worker sent a start command.
    void Start(WorkerProxy&& worker) {
      CHECK_LT(pending.size(), n_workers);
      CHECK_LE(n_shutdown, n_workers);
      pending.emplace_back(std::forward<WorkerProxy>(worker));
      CHECK_LE(pending.size(), n_workers);
    }
    // A worker sent a shutdown command.
    void Shutdown() {
      CHECK_GE(n_shutdown, 0);
      CHECK_LT(n_shutdown, n_workers);
      ++n_shutdown;
      CHECK_LE(n_shutdown, n_workers);
    }
    // A worker reported an error.
    void Error() {
      CHECK_LE(pending.size(), n_workers);
      CHECK_LE(n_shutdown, n_workers);
      during_restart = true;
    }
    // Whether all workers have connected and bootstrap can proceed.
    [[nodiscard]] bool Ready() const {
      CHECK_LE(pending.size(), n_workers);
      return static_cast<std::int32_t>(pending.size()) == n_workers;
    }
    // A reset, performed after a successful bootstrap.
    void Bootstrap() {
      CHECK_EQ(pending.size(), n_workers);
      CHECK_LE(n_shutdown, n_workers);
      n_shutdown = 0;
      during_restart = false;
      pending.clear();
    }
    [[nodiscard]] bool ShouldContinue() const {
      CHECK_LE(pending.size(), n_workers);
      CHECK_LE(n_shutdown, n_workers);
      // - Without error, we should shutdown after all workers are offline.
      // - With error, all workers are offline, and we have during_restart as true.
      return n_shutdown != n_workers || during_restart;
    }
  };

  return std::async(std::launch::async, [this] {
    State state{this->n_workers_};
    while (state.ShouldContinue()) {
      TCPSocket sock;
      SockAddrV4 addr;
      auto rc = listener_.Accept(&sock, &addr);
      if (!rc.OK()) {
        return Fail("Failed to accept connection.", std::move(rc));
      }

      auto worker = WorkerProxy{n_workers_, std::move(sock), std::move(addr)};
      if (!worker.Status().OK()) {
        return Fail("Failed to initialize worker proxy.", std::move(worker.Status()));
      }
      switch (worker.Command()) {
        case proto::CMD::kStart: {
          state.Start(std::move(worker));
          if (state.Ready()) {
            rc = this->Bootstrap(&state.pending);
            state.Bootstrap();
          }
          if (!rc.OK()) {
            return rc;
          }
          continue;
        }
        case proto::CMD::kShutdown: {
          state.Shutdown();
          continue;
        }
        case proto::CMD::kError: {
          if (state.during_restart) {
            // Only react to the first error report; later ones are a
            // consequence of the shutdown signal sent below.
            continue;
          }
          state.Error();
          auto msg = worker.Msg();
          auto code = worker.Code();
          LOG(WARNING) << "Received error from [" << worker.Host() << ":" << worker.Rank()
                       << "]: " << msg << " code:" << code;
          auto host = worker.Host();
          // We signal all workers for the error, if they haven't aborted already.
          for (auto& w : worker_error_handles_) {
            if (w.first == host) {
              continue;
            }
            TCPSocket out;
            // retry is set to 1, just let the worker timeout or error. Otherwise the
            // tracker and the worker might be waiting for each other.
            auto conn_rc = Connect(w.first, w.second, 1, timeout_, &out);
            if (!conn_rc.OK()) {
              // This is best-effort signaling; the worker might have aborted
              // already, so skip an unreachable peer instead of failing.
              continue;
            }
            // send signal to stop the worker.
            proto::ShutdownCMD shutdown;
            conn_rc = shutdown.Send(&out);
            if (!conn_rc.OK()) {
              return Fail("Failed to inform workers to stop.");
            }
          }
          continue;
        }
        case proto::CMD::kPrint: {
          LOG(CONSOLE) << worker.Msg();
          continue;
        }
        case proto::CMD::kInvalid:
        default: {
          return Fail("Invalid command received.");
        }
      }
    }
    return Success();
  });
}
[[nodiscard]] Result GetHostAddress(std::string* out) {
  auto rc = GetHostName(out);
  if (!rc.OK()) {
    return rc;
  }
  auto host = gethostbyname(out->c_str());
  // get ip address from host
  std::string ip;
  rc = INetNToP(host, &ip);
  if (!rc.OK()) {
    return rc;
  }

  if (!(ip.size() >= 4 && ip.substr(0, 4) == "127.")) {
    // return if this is a public IP address.
    // not entirely accurate, we have other reserved IPs
    *out = ip;
    return Success();
  }

  // Create a UDP socket to probe the public IP address, it's fine even if it's
  // unreachable.
  auto sock = socket(AF_INET, SOCK_DGRAM, 0);
  if (sock == -1) {
    return Fail("Failed to create socket.");
  }
  // Make sure the probing socket is released on every exit path.
  auto close_probe = [&] { return system::CloseSocket(sock); };

  auto paddr = MakeSockAddress(StringView{"10.255.255.255"}, 1);
  sockaddr const* addr_handle = reinterpret_cast<const sockaddr*>(&paddr.V4().Handle());
  socklen_t addr_len{sizeof(paddr.V4().Handle())};
  auto err = connect(sock, addr_handle, addr_len);
  if (err != 0) {
    close_probe();  // best-effort cleanup; report the connect failure instead.
    return system::FailWithCode("Failed to find IP address.");
  }

  // get the IP address from the socket descriptor
  struct sockaddr_in addr;
  socklen_t len = sizeof(addr);
  if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr), &len) == -1) {
    close_probe();  // best-effort cleanup; report the getsockname failure.
    return Fail("Failed to get sock name.");
  }
  ip = inet_ntoa(addr.sin_addr);

  if (close_probe() != 0) {
    return system::FailWithCode("Failed to close socket.");
  }
  *out = ip;
  return Success();
}
} // namespace xgboost::collective

141
src/collective/tracker.h Normal file
View File

@@ -0,0 +1,141 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <future> // for future
#include <string> // for string
#include <utility> // for pair
#include <vector> // for vector
#include "protocol.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/collective/socket.h" // for TCPSocket
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
/**
*
* @brief Implementation of RABIT tracker.
*
* * What is a tracker
*
* The implementation of collective follows what RABIT did in the past. It requires a
* tracker to coordinate initialization and error recovery of workers. While the
 * original implementation attempted to attain error resilience inside the collective
 * module, which turned out to be too challenging due to the large amount of external
 * state. The new implementation here differs from RABIT in that neither state
 * recovery nor resilience is handled inside the collective; it merely provides the
* mechanism to signal error to other workers through the use of a centralized tracker.
*
 * There are three major functionalities provided by a tracker, namely:
* - Initialization. Share the node addresses among all workers.
* - Logging.
* - Signal error. If an exception is thrown in one (or many) of the workers, it can
* signal an error to the tracker and the tracker will notify other workers.
*/
class Tracker {
 protected:
  // Expected number of workers in the communication group.
  std::int32_t n_workers_{0};
  // Port the tracker listens on; -1 until bound.
  std::int32_t port_{-1};
  // Timeout used when connecting out to workers.
  std::chrono::seconds timeout_{0};

 public:
  explicit Tracker(Json const& config);
  Tracker(std::int32_t n_workers, std::int32_t port, std::chrono::seconds timeout)
      : n_workers_{n_workers}, port_{port}, timeout_{timeout} {}
  virtual ~Tracker() noexcept(false) {}  // NOLINT
  /**
   * @brief Run the tracker until all workers have shut down.
   */
  [[nodiscard]] virtual std::future<Result> Run() = 0;
  /**
   * @brief Arguments workers need in order to connect to this tracker.
   */
  [[nodiscard]] virtual Json WorkerArgs() const = 0;
  [[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; }
};
class RabitTracker : public Tracker {
  // A wrapper for a connected worker socket, holding the parsed handshake.
  class WorkerProxy {
    TCPSocket sock_;
    proto::PeerInfo info_;
    std::int32_t eport_{0};  // port for signaling errors to the worker
    std::int32_t world_{-1};
    std::string task_id_;
    proto::CMD cmd_{proto::CMD::kInvalid};
    std::string msg_;
    std::int32_t code_{0};
    Result rc_;

   public:
    explicit WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr);
    WorkerProxy(WorkerProxy const& that) = delete;
    WorkerProxy(WorkerProxy&& that) = default;
    WorkerProxy& operator=(WorkerProxy const&) = delete;
    WorkerProxy& operator=(WorkerProxy&&) = default;

    [[nodiscard]] auto Host() const { return info_.host; }
    [[nodiscard]] auto TaskID() const { return task_id_; }
    [[nodiscard]] auto Port() const { return info_.port; }
    [[nodiscard]] auto Rank() const { return info_.rank; }
    [[nodiscard]] auto ErrorPort() const { return eport_; }
    [[nodiscard]] auto Command() const { return cmd_; }
    [[nodiscard]] auto Msg() const { return msg_; }
    [[nodiscard]] auto Code() const { return code_; }
    [[nodiscard]] Result const& Status() const { return rc_; }
    [[nodiscard]] Result& Status() { return rc_; }
    void Send(StringView value) { this->sock_.Send(value); }
  };
  // provide an ordering for workers, this helps us get deterministic topology.
  struct WorkerCmp {
    [[nodiscard]] bool operator()(WorkerProxy const& lhs, WorkerProxy const& rhs) const {
      auto const& lh = lhs.Host();
      auto const& rh = rhs.Host();
      if (lh != rh) {
        return lh < rh;
      }
      return lhs.TaskID() < rhs.TaskID();
    }
  };

 private:
  std::string host_;
  // record for how to reach out to workers if error happens.
  std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
  // listening socket for incoming workers.
  TCPSocket listener_;

  // Share the ring topology with all connected workers.
  Result Bootstrap(std::vector<WorkerProxy>* p_workers);

 public:
  explicit RabitTracker(StringView host, std::int32_t n_workers, std::int32_t port,
                        std::chrono::seconds timeout)
      : Tracker{n_workers, port, timeout}, host_{host.c_str(), host.size()} {
    listener_ = TCPSocket::Create(SockDomain::kV4);
    auto rc = listener_.Bind(host, &this->port_);
    CHECK(rc.OK()) << rc.Report();
    listener_.Listen();
  }
  explicit RabitTracker(Json const& config);
  ~RabitTracker() noexcept(false) override = default;

  [[nodiscard]] std::future<Result> Run() override;
  [[nodiscard]] std::int32_t Port() const { return port_; }
  [[nodiscard]] Json WorkerArgs() const override {
    Json args{Object{}};
    args["DMLC_TRACKER_URI"] = String{host_};
    args["DMLC_TRACKER_PORT"] = this->Port();
    return args;
  }
};
// Probe the public IP address of the host; we need a better method.
//
// This is directly translated from the previous Python implementation. We should find a
// more rigorous approach; it could use some expertise in network programming.
[[nodiscard]] Result GetHostAddress(std::string* out);
} // namespace xgboost::collective