enable ROCm on latest XGBoost
This commit is contained in:
@@ -15,8 +15,7 @@
|
||||
|
||||
#include "communicator-inl.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
namespace xgboost::collective {
|
||||
|
||||
/**
|
||||
* @brief Find the global sum of the given values across all workers.
|
||||
@@ -31,10 +30,9 @@ namespace collective {
|
||||
* @param size Number of values to sum.
|
||||
*/
|
||||
template <typename T>
|
||||
void GlobalSum(MetaInfo const& info, int device, T* values, size_t size) {
|
||||
void GlobalSum(MetaInfo const& info, DeviceOrd device, T* values, size_t size) {
|
||||
if (info.IsRowSplit()) {
|
||||
collective::AllReduce<collective::Operation::kSum>(device, values, size);
|
||||
collective::AllReduce<collective::Operation::kSum>(device.ordinal, values, size);
|
||||
}
|
||||
}
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::collective
|
||||
|
||||
88
src/collective/allgather.cc
Normal file
88
src/collective/allgather.cc
Normal file
@@ -0,0 +1,88 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include "allgather.h"
|
||||
|
||||
#include <algorithm> // for min, copy_n
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t, int32_t, int64_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <numeric> // for partial_sum
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "comm.h" // for Comm, Channel
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective::cpu_impl {
|
||||
Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size_t segment_size,
|
||||
std::int32_t worker_off, std::shared_ptr<Channel> prev_ch,
|
||||
std::shared_ptr<Channel> next_ch) {
|
||||
auto world = comm.World();
|
||||
auto rank = comm.Rank();
|
||||
CHECK_LT(worker_off, world);
|
||||
|
||||
for (std::int32_t r = 0; r < world; ++r) {
|
||||
auto send_rank = (rank + world - r + worker_off) % world;
|
||||
auto send_off = send_rank * segment_size;
|
||||
send_off = std::min(send_off, data.size_bytes());
|
||||
auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off));
|
||||
next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
|
||||
|
||||
auto recv_rank = (rank + world - r - 1 + worker_off) % world;
|
||||
auto recv_off = recv_rank * segment_size;
|
||||
recv_off = std::min(recv_off, data.size_bytes());
|
||||
auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off));
|
||||
prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
|
||||
auto rc = prev_ch->Block();
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
}
|
||||
|
||||
return Success();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int8_t const> data,
|
||||
common::Span<std::int8_t> erased_result) {
|
||||
auto world = comm.World();
|
||||
auto rank = comm.Rank();
|
||||
|
||||
auto prev = BootstrapPrev(rank, comm.World());
|
||||
auto next = BootstrapNext(rank, comm.World());
|
||||
|
||||
auto prev_ch = comm.Chan(prev);
|
||||
auto next_ch = comm.Chan(next);
|
||||
|
||||
// get worker offset
|
||||
std::vector<std::int64_t> offset(world + 1, 0);
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
|
||||
CHECK_EQ(*offset.cbegin(), 0);
|
||||
|
||||
// copy data
|
||||
auto current = erased_result.subspan(offset[rank], data.size_bytes());
|
||||
auto erased_data = EraseType(data);
|
||||
std::copy_n(erased_data.data(), erased_data.size(), current.data());
|
||||
|
||||
for (std::int32_t r = 0; r < world; ++r) {
|
||||
auto send_rank = (rank + world - r) % world;
|
||||
auto send_off = offset[send_rank];
|
||||
auto send_size = sizes[send_rank];
|
||||
auto send_seg = erased_result.subspan(send_off, send_size);
|
||||
next_ch->SendAll(send_seg);
|
||||
|
||||
auto recv_rank = (rank + world - r - 1) % world;
|
||||
auto recv_off = offset[recv_rank];
|
||||
auto recv_size = sizes[recv_rank];
|
||||
auto recv_seg = erased_result.subspan(recv_off, recv_size);
|
||||
prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
|
||||
|
||||
auto rc = prev_ch->Block();
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
}
|
||||
return comm.Block();
|
||||
}
|
||||
} // namespace xgboost::collective::cpu_impl
|
||||
72
src/collective/allgather.h
Normal file
72
src/collective/allgather.h
Normal file
@@ -0,0 +1,72 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <numeric> // for accumulate
|
||||
#include <type_traits> // for remove_cv_t
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "comm.h" // for Comm, Channel, EraseType
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace cpu_impl {
|
||||
/**
|
||||
* @param worker_off Segment offset. For example, if the rank 2 worker specifis worker_off
|
||||
* = 1, then it owns the third segment.
|
||||
*/
|
||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::size_t segment_size, std::int32_t worker_off,
|
||||
std::shared_ptr<Channel> prev_ch,
|
||||
std::shared_ptr<Channel> next_ch);
|
||||
|
||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int8_t const> data,
|
||||
common::Span<std::int8_t> erased_result);
|
||||
} // namespace cpu_impl
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
|
||||
auto n_bytes = sizeof(T) * size;
|
||||
auto erased = EraseType(data);
|
||||
|
||||
auto rank = comm.Rank();
|
||||
auto prev = BootstrapPrev(rank, comm.World());
|
||||
auto next = BootstrapNext(rank, comm.World());
|
||||
|
||||
auto prev_ch = comm.Chan(prev);
|
||||
auto next_ch = comm.Chan(next);
|
||||
auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes, 0, prev_ch, next_ch);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
return comm.Block();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<T> data,
|
||||
std::vector<std::remove_cv_t<T>>* p_out) {
|
||||
auto world = comm.World();
|
||||
auto rank = comm.Rank();
|
||||
|
||||
std::vector<std::int64_t> sizes(world, 0);
|
||||
sizes[rank] = data.size_bytes();
|
||||
auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
std::vector<T>& result = *p_out;
|
||||
auto n_total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
|
||||
result.resize(n_total_bytes / sizeof(T));
|
||||
auto h_result = common::Span{result.data(), result.size()};
|
||||
auto erased_result = EraseType(h_result);
|
||||
auto erased_data = EraseType(data);
|
||||
|
||||
return cpu_impl::RingAllgatherV(comm, sizes, erased_data, erased_result);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
90
src/collective/allreduce.cc
Normal file
90
src/collective/allreduce.cc
Normal file
@@ -0,0 +1,90 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include "allreduce.h"
|
||||
|
||||
#include <algorithm> // for min
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t, int8_t
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../data/array_interface.h" // for Type, DispatchDType
|
||||
#include "allgather.h" // for RingAllgather
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective::cpu_impl {
|
||||
template <typename T>
|
||||
Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::size_t n_bytes_in_seg, Func const& op) {
|
||||
auto rank = comm.Rank();
|
||||
auto world = comm.World();
|
||||
|
||||
auto dst_rank = BootstrapNext(rank, world);
|
||||
auto src_rank = BootstrapPrev(rank, world);
|
||||
auto next_ch = comm.Chan(dst_rank);
|
||||
auto prev_ch = comm.Chan(src_rank);
|
||||
|
||||
std::vector<std::int8_t> buffer(n_bytes_in_seg, 0);
|
||||
auto s_buf = common::Span{buffer.data(), buffer.size()};
|
||||
|
||||
for (std::int32_t r = 0; r < world - 1; ++r) {
|
||||
// send to ring next
|
||||
auto send_off = ((rank + world - r) % world) * n_bytes_in_seg;
|
||||
send_off = std::min(send_off, data.size_bytes());
|
||||
auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg);
|
||||
auto send_seg = data.subspan(send_off, seg_nbytes);
|
||||
|
||||
next_ch->SendAll(send_seg);
|
||||
|
||||
// receive from ring prev
|
||||
auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg;
|
||||
recv_off = std::min(recv_off, data.size_bytes());
|
||||
seg_nbytes = std::min(data.size_bytes() - recv_off, n_bytes_in_seg);
|
||||
CHECK_EQ(seg_nbytes % sizeof(T), 0);
|
||||
auto recv_seg = data.subspan(recv_off, seg_nbytes);
|
||||
auto seg = s_buf.subspan(0, recv_seg.size());
|
||||
|
||||
prev_ch->RecvAll(seg);
|
||||
auto rc = prev_ch->Block();
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
// accumulate to recv_seg
|
||||
CHECK_EQ(seg.size(), recv_seg.size());
|
||||
op(seg, recv_seg);
|
||||
}
|
||||
|
||||
return Success();
|
||||
}
|
||||
|
||||
Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const& op,
|
||||
ArrayInterfaceHandler::Type type) {
|
||||
return DispatchDType(type, [&](auto t) {
|
||||
using T = decltype(t);
|
||||
// Divide the data into segments according to the number of workers.
|
||||
auto n_bytes_elem = sizeof(T);
|
||||
CHECK_EQ(data.size_bytes() % n_bytes_elem, 0);
|
||||
auto n = data.size_bytes() / n_bytes_elem;
|
||||
auto world = comm.World();
|
||||
auto n_bytes_in_seg = common::DivRoundUp(n, world) * sizeof(T);
|
||||
auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
auto prev = BootstrapPrev(comm.Rank(), comm.World());
|
||||
auto next = BootstrapNext(comm.Rank(), comm.World());
|
||||
auto prev_ch = comm.Chan(prev);
|
||||
auto next_ch = comm.Chan(next);
|
||||
|
||||
rc = RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
return comm.Block();
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective::cpu_impl
|
||||
39
src/collective/allreduce.h
Normal file
39
src/collective/allreduce.h
Normal file
@@ -0,0 +1,39 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstdint> // for int8_t
|
||||
#include <functional> // for function
|
||||
#include <type_traits> // for is_invocable_v
|
||||
|
||||
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||
#include "comm.h" // for Comm, RestoreType
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace cpu_impl {
|
||||
using Func =
|
||||
std::function<void(common::Span<std::int8_t const> lhs, common::Span<std::int8_t> out)>;
|
||||
|
||||
Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const& op,
|
||||
ArrayInterfaceHandler::Type type);
|
||||
} // namespace cpu_impl
|
||||
|
||||
template <typename T, typename Fn>
|
||||
std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>, Result> Allreduce(
|
||||
Comm const& comm, common::Span<T> data, Fn redop) {
|
||||
auto erased = EraseType(data);
|
||||
auto type = ToDType<T>::kType;
|
||||
|
||||
auto erased_fn = [type, redop](common::Span<std::int8_t const> lhs,
|
||||
common::Span<std::int8_t> out) {
|
||||
CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
|
||||
auto lhs_t = RestoreType<T const>(lhs);
|
||||
auto rhs_t = RestoreType<T>(out);
|
||||
redop(lhs_t, rhs_t);
|
||||
};
|
||||
|
||||
return cpu_impl::RingAllreduce(comm, erased, erased_fn, type);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
84
src/collective/broadcast.cc
Normal file
84
src/collective/broadcast.cc
Normal file
@@ -0,0 +1,84 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include "broadcast.h"
|
||||
|
||||
#include <cmath> // for ceil, log2
|
||||
#include <cstdint> // for int32_t, int8_t
|
||||
#include <utility> // for move
|
||||
|
||||
#include "../common/bitfield.h" // for TrailingZeroBits, RBitField32
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective::cpu_impl {
|
||||
namespace {
|
||||
std::int32_t ShiftedParentRank(std::int32_t shifted_rank, std::int32_t depth) {
|
||||
std::uint32_t mask{std::uint32_t{0} - 1}; // Oxff...
|
||||
RBitField32 maskbits{common::Span<std::uint32_t>{&mask, 1}};
|
||||
RBitField32 rankbits{
|
||||
common::Span<std::uint32_t>{reinterpret_cast<std::uint32_t*>(&shifted_rank), 1}};
|
||||
// prepare for counting trailing zeros.
|
||||
for (std::int32_t i = 0; i < depth + 1; ++i) {
|
||||
if (rankbits.Check(i)) {
|
||||
maskbits.Set(i);
|
||||
} else {
|
||||
maskbits.Clear(i);
|
||||
}
|
||||
}
|
||||
|
||||
CHECK_NE(mask, 0);
|
||||
auto k = TrailingZeroBits(mask);
|
||||
auto shifted_parent = shifted_rank - (1 << k);
|
||||
return shifted_parent;
|
||||
}
|
||||
|
||||
// Shift the root node to rank 0
|
||||
std::int32_t ShiftLeft(std::int32_t rank, std::int32_t world, std::int32_t root) {
|
||||
auto shifted_rank = (rank + world - root) % world;
|
||||
return shifted_rank;
|
||||
}
|
||||
// shift back to the original rank
|
||||
std::int32_t ShiftRight(std::int32_t rank, std::int32_t world, std::int32_t root) {
|
||||
auto orig = (rank + root) % world;
|
||||
return orig;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t root) {
|
||||
// Binomial tree broadcast
|
||||
// * Wiki
|
||||
// https://en.wikipedia.org/wiki/Broadcast_(parallel_pattern)#Binomial_Tree_Broadcast
|
||||
// * Impl
|
||||
// https://people.mpi-inf.mpg.de/~mehlhorn/ftp/NewToolbox/collective.pdf
|
||||
|
||||
auto rank = comm.Rank();
|
||||
auto world = comm.World();
|
||||
|
||||
// shift root to rank 0
|
||||
auto shifted_rank = ShiftLeft(rank, world, root);
|
||||
std::int32_t depth = std::ceil(std::log2(static_cast<double>(world))) - 1;
|
||||
|
||||
if (shifted_rank != 0) { // not root
|
||||
auto parent = ShiftRight(ShiftedParentRank(shifted_rank, depth), world, root);
|
||||
comm.Chan(parent)->RecvAll(data);
|
||||
auto rc = comm.Chan(parent)->Block();
|
||||
if (!rc.OK()) {
|
||||
return Fail("broadcast failed.", std::move(rc));
|
||||
}
|
||||
}
|
||||
|
||||
for (std::int32_t i = depth; i >= 0; --i) {
|
||||
CHECK_GE((i + 1), 0); // weird clang-tidy error that i might be negative
|
||||
if (shifted_rank % (1 << (i + 1)) == 0 && shifted_rank + (1 << i) < world) {
|
||||
auto sft_peer = shifted_rank + (1 << i);
|
||||
auto peer = ShiftRight(sft_peer, world, root);
|
||||
CHECK_NE(peer, root);
|
||||
comm.Chan(peer)->SendAll(data);
|
||||
}
|
||||
}
|
||||
|
||||
return comm.Block();
|
||||
}
|
||||
} // namespace xgboost::collective::cpu_impl
|
||||
26
src/collective/broadcast.h
Normal file
26
src/collective/broadcast.h
Normal file
@@ -0,0 +1,26 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstdint> // for int32_t, int8_t
|
||||
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/collective/result.h" // for
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace cpu_impl {
|
||||
Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t root);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief binomial tree broadcast is used on CPU with the default implementation.
|
||||
*/
|
||||
template <typename T>
|
||||
[[nodiscard]] Result Broadcast(Comm const& comm, common::Span<T> data, std::int32_t root) {
|
||||
auto n_total_bytes = data.size_bytes();
|
||||
auto erased =
|
||||
common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes};
|
||||
return cpu_impl::Broadcast(comm, erased, root);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
304
src/collective/comm.cc
Normal file
304
src/collective/comm.cc
Normal file
@@ -0,0 +1,304 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include "comm.h"
|
||||
|
||||
#include <algorithm> // for copy
|
||||
#include <chrono> // for seconds
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string> // for string
|
||||
#include <utility> // for move, forward
|
||||
|
||||
#include "allgather.h"
|
||||
#include "protocol.h" // for kMagic
|
||||
#include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||
#include "xgboost/json.h" // for Json, Object
|
||||
#include "xgboost/string_view.h" // for StringView
|
||||
|
||||
namespace xgboost::collective {
|
||||
Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t retry, std::string task_id)
|
||||
: timeout_{timeout},
|
||||
retry_{retry},
|
||||
tracker_{host, port, -1},
|
||||
task_id_{std::move(task_id)},
|
||||
loop_{std::make_shared<Loop>(timeout)} {}
|
||||
|
||||
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string const& task_id, TCPSocket* out, std::int32_t rank,
|
||||
std::int32_t world) {
|
||||
// get information from tracker
|
||||
CHECK(!info.host.empty());
|
||||
auto rc = Connect(info.host, info.port, retry, timeout, out);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to connect to the tracker.", std::move(rc));
|
||||
}
|
||||
|
||||
TCPSocket& tracker = *out;
|
||||
return std::move(rc)
|
||||
<< [&] { return tracker.NonBlocking(false); }
|
||||
<< [&] { return tracker.RecvTimeout(timeout); }
|
||||
<< [&] { return proto::Magic{}.Verify(&tracker); }
|
||||
<< [&] { return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id); };
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Comm::ConnectTracker(TCPSocket* out) const {
|
||||
return ConnectTrackerImpl(this->TrackerInfo(), this->Timeout(), this->retry_, this->task_id_, out,
|
||||
this->Rank(), this->World());
|
||||
}
|
||||
|
||||
[[nodiscard]] Result ConnectWorkers(Comm const& comm, TCPSocket* listener, std::int32_t lport,
|
||||
proto::PeerInfo ninfo, std::chrono::seconds timeout,
|
||||
std::int32_t retry,
|
||||
std::vector<std::shared_ptr<TCPSocket>>* out_workers) {
|
||||
auto next = std::make_shared<TCPSocket>();
|
||||
auto prev = std::make_shared<TCPSocket>();
|
||||
|
||||
auto rc = Success() << [&] {
|
||||
auto rc = Connect(ninfo.host, ninfo.port, retry, timeout, next.get());
|
||||
if (!rc.OK()) {
|
||||
return Fail("Bootstrap failed to connect to ring next.", std::move(rc));
|
||||
}
|
||||
return rc;
|
||||
} << [&] {
|
||||
return next->NonBlocking(true);
|
||||
} << [&] {
|
||||
SockAddrV4 addr;
|
||||
return listener->Accept(prev.get(), &addr);
|
||||
} << [&] { return prev->NonBlocking(true); };
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
// exchange host name and port
|
||||
std::vector<std::int8_t> buffer(HOST_NAME_MAX * comm.World(), 0);
|
||||
auto s_buffer = common::Span{buffer.data(), buffer.size()};
|
||||
auto next_host = s_buffer.subspan(HOST_NAME_MAX * comm.Rank(), HOST_NAME_MAX);
|
||||
if (next_host.size() < ninfo.host.size()) {
|
||||
return Fail("Got an invalid host name.");
|
||||
}
|
||||
std::copy(ninfo.host.cbegin(), ninfo.host.cend(), next_host.begin());
|
||||
|
||||
auto prev_ch = std::make_shared<Channel>(comm, prev);
|
||||
auto next_ch = std::make_shared<Channel>(comm, next);
|
||||
|
||||
auto block = [&] {
|
||||
for (auto ch : {prev_ch, next_ch}) {
|
||||
auto rc = ch->Block();
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
}
|
||||
return Success();
|
||||
};
|
||||
|
||||
rc = std::move(rc) << [&] {
|
||||
return cpu_impl::RingAllgather(comm, s_buffer, HOST_NAME_MAX, 0, prev_ch, next_ch);
|
||||
} << [&] { return block(); };
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to get host names from peers.", std::move(rc));
|
||||
}
|
||||
|
||||
std::vector<std::int32_t> peers_port(comm.World(), -1);
|
||||
peers_port[comm.Rank()] = ninfo.port;
|
||||
rc = std::move(rc) << [&] {
|
||||
auto s_ports = common::Span{reinterpret_cast<std::int8_t*>(peers_port.data()),
|
||||
peers_port.size() * sizeof(ninfo.port)};
|
||||
return cpu_impl::RingAllgather(comm, s_ports, sizeof(ninfo.port), 0, prev_ch, next_ch);
|
||||
} << [&] { return block(); };
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to get the port from peers.", std::move(rc));
|
||||
}
|
||||
|
||||
std::vector<proto::PeerInfo> peers(comm.World());
|
||||
for (auto r = 0; r < comm.World(); ++r) {
|
||||
auto nhost = s_buffer.subspan(HOST_NAME_MAX * r, HOST_NAME_MAX);
|
||||
auto nport = peers_port[r];
|
||||
auto nrank = BootstrapNext(r, comm.World());
|
||||
|
||||
peers[nrank] = {std::string{reinterpret_cast<char const*>(nhost.data())}, nport, nrank};
|
||||
}
|
||||
CHECK_EQ(peers[comm.Rank()].port, lport);
|
||||
for (auto const& p : peers) {
|
||||
CHECK_NE(p.port, -1);
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<TCPSocket>>& workers = *out_workers;
|
||||
workers.resize(comm.World());
|
||||
|
||||
for (std::int32_t r = (comm.Rank() + 1); r < comm.World(); ++r) {
|
||||
auto const& peer = peers[r];
|
||||
std::shared_ptr<TCPSocket> worker{TCPSocket::CreatePtr(comm.Domain())};
|
||||
rc = std::move(rc)
|
||||
<< [&] { return Connect(peer.host, peer.port, retry, timeout, worker.get()); }
|
||||
<< [&] { return worker->RecvTimeout(timeout); };
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
auto rank = comm.Rank();
|
||||
auto n_bytes = worker->SendAll(&rank, sizeof(comm.Rank()));
|
||||
if (n_bytes != sizeof(comm.Rank())) {
|
||||
return Fail("Failed to send rank.");
|
||||
}
|
||||
workers[r] = std::move(worker);
|
||||
}
|
||||
|
||||
for (std::int32_t r = 0; r < comm.Rank(); ++r) {
|
||||
SockAddrV4 addr;
|
||||
auto peer = std::shared_ptr<TCPSocket>(TCPSocket::CreatePtr(comm.Domain()));
|
||||
rc = std::move(rc) << [&] { return listener->Accept(peer.get(), &addr); }
|
||||
<< [&] { return peer->RecvTimeout(timeout); };
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
std::int32_t rank{-1};
|
||||
auto n_bytes = peer->RecvAll(&rank, sizeof(rank));
|
||||
if (n_bytes != sizeof(comm.Rank())) {
|
||||
return Fail("Failed to recv rank.");
|
||||
}
|
||||
workers[rank] = std::move(peer);
|
||||
}
|
||||
|
||||
for (std::int32_t r = 0; r < comm.World(); ++r) {
|
||||
if (r == comm.Rank()) {
|
||||
continue;
|
||||
}
|
||||
CHECK(workers[r]);
|
||||
}
|
||||
|
||||
return Success();
|
||||
}
|
||||
|
||||
RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t retry, std::string task_id)
|
||||
: Comm{std::move(host), port, timeout, retry, std::move(task_id)} {
|
||||
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result RabitComm::Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string task_id) {
|
||||
TCPSocket tracker;
|
||||
std::int32_t world{-1};
|
||||
auto rc = ConnectTrackerImpl(this->TrackerInfo(), timeout, retry, task_id, &tracker, this->Rank(),
|
||||
world);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Bootstrap failed.", std::move(rc));
|
||||
}
|
||||
|
||||
this->domain_ = tracker.Domain();
|
||||
|
||||
// Start command
|
||||
TCPSocket listener = TCPSocket::Create(tracker.Domain());
|
||||
std::int32_t lport = listener.BindHost();
|
||||
listener.Listen();
|
||||
|
||||
// create worker for listening to error notice.
|
||||
auto domain = tracker.Domain();
|
||||
std::shared_ptr<TCPSocket> error_sock{TCPSocket::CreatePtr(domain)};
|
||||
auto eport = error_sock->BindHost();
|
||||
error_sock->Listen();
|
||||
error_worker_ = std::thread{[this, error_sock = std::move(error_sock)] {
|
||||
auto conn = error_sock->Accept();
|
||||
// On Windows accept returns an invalid socket after network is shutdown.
|
||||
if (conn.IsClosed()) {
|
||||
return;
|
||||
}
|
||||
LOG(WARNING) << "Another worker is running into error.";
|
||||
std::string scmd;
|
||||
conn.Recv(&scmd);
|
||||
auto jcmd = Json::Load(scmd);
|
||||
auto rc = this->Shutdown();
|
||||
if (!rc.OK()) {
|
||||
LOG(WARNING) << "Fail to shutdown worker:" << rc.Report();
|
||||
}
|
||||
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
|
||||
exit(-1);
|
||||
#else
|
||||
LOG(FATAL) << rc.Report();
|
||||
#endif
|
||||
}};
|
||||
error_worker_.detach();
|
||||
|
||||
proto::Start start;
|
||||
rc = std::move(rc) << [&] { return start.WorkerSend(lport, &tracker, eport); }
|
||||
<< [&] { return start.WorkerRecv(&tracker, &world); };
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
this->world_ = world;
|
||||
|
||||
// get ring neighbors
|
||||
std::string snext;
|
||||
tracker.Recv(&snext);
|
||||
auto jnext = Json::Load(StringView{snext});
|
||||
|
||||
proto::PeerInfo ninfo{jnext};
|
||||
|
||||
// get the rank of this worker
|
||||
this->rank_ = BootstrapPrev(ninfo.rank, world);
|
||||
this->tracker_.rank = rank_;
|
||||
|
||||
std::vector<std::shared_ptr<TCPSocket>> workers;
|
||||
rc = ConnectWorkers(*this, &listener, lport, ninfo, timeout, retry, &workers);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
CHECK(this->channels_.empty());
|
||||
for (auto& w : workers) {
|
||||
if (w) {
|
||||
w->SetNoDelay();
|
||||
rc = w->NonBlocking(true);
|
||||
}
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
this->channels_.emplace_back(std::make_shared<Channel>(*this, w));
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
RabitComm::~RabitComm() noexcept(false) {
|
||||
if (!IsDistributed()) {
|
||||
return;
|
||||
}
|
||||
auto rc = this->Shutdown();
|
||||
if (!rc.OK()) {
|
||||
LOG(WARNING) << rc.Report();
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]] Result RabitComm::Shutdown() {
|
||||
TCPSocket tracker;
|
||||
return Success() << [&] {
|
||||
return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World());
|
||||
} << [&] {
|
||||
return this->Block();
|
||||
} << [&] {
|
||||
Json jcmd{Object{}};
|
||||
jcmd["cmd"] = Integer{static_cast<std::int32_t>(proto::CMD::kShutdown)};
|
||||
auto scmd = Json::Dump(jcmd);
|
||||
auto n_bytes = tracker.Send(scmd);
|
||||
if (n_bytes != scmd.size()) {
|
||||
return Fail("Faled to send cmd.");
|
||||
}
|
||||
return Success();
|
||||
};
|
||||
}
|
||||
|
||||
[[nodiscard]] Result RabitComm::LogTracker(std::string msg) const {
|
||||
TCPSocket out;
|
||||
proto::Print print;
|
||||
return Success() << [&] { return this->ConnectTracker(&out); }
|
||||
<< [&] { return print.WorkerSend(&out, msg); };
|
||||
}
|
||||
|
||||
[[nodiscard]] Result RabitComm::SignalError(Result const& res) {
|
||||
TCPSocket out;
|
||||
return Success() << [&] { return this->ConnectTracker(&out); }
|
||||
<< [&] { return proto::ErrorCMD{}.WorkerSend(&out, res); };
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
156
src/collective/comm.h
Normal file
156
src/collective/comm.h
Normal file
@@ -0,0 +1,156 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <chrono> // for seconds
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
#include <type_traits> // for remove_const_t
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "loop.h" // for Loop
|
||||
#include "protocol.h" // for PeerInfo
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
inline constexpr std::int32_t DefaultTimeoutSec() { return 300; } // 5min
|
||||
inline constexpr std::int32_t DefaultRetry() { return 3; }
|
||||
|
||||
// indexing into the ring
|
||||
inline std::int32_t BootstrapNext(std::int32_t r, std::int32_t world) {
|
||||
auto nrank = (r + world + 1) % world;
|
||||
return nrank;
|
||||
}
|
||||
|
||||
inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) {
|
||||
auto nrank = (r + world - 1) % world;
|
||||
return nrank;
|
||||
}
|
||||
|
||||
class Channel;
|
||||
|
||||
/**
|
||||
* @brief Base communicator storing info about the tracker and other communicators.
|
||||
*/
|
||||
class Comm {
|
||||
protected:
|
||||
std::int32_t world_{1};
|
||||
std::int32_t rank_{0};
|
||||
std::chrono::seconds timeout_{DefaultTimeoutSec()};
|
||||
std::int32_t retry_{DefaultRetry()};
|
||||
|
||||
proto::PeerInfo tracker_;
|
||||
SockDomain domain_{SockDomain::kV4};
|
||||
std::thread error_worker_;
|
||||
std::string task_id_;
|
||||
std::vector<std::shared_ptr<Channel>> channels_;
|
||||
std::shared_ptr<Loop> loop_{new Loop{std::chrono::seconds{
|
||||
DefaultTimeoutSec()}}}; // fixme: require federated comm to have a timeout
|
||||
|
||||
public:
|
||||
Comm() = default;
|
||||
Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string task_id);
|
||||
virtual ~Comm() noexcept(false) {} // NOLINT
|
||||
|
||||
Comm(Comm const& that) = delete;
|
||||
Comm& operator=(Comm const& that) = delete;
|
||||
Comm(Comm&& that) = delete;
|
||||
Comm& operator=(Comm&& that) = delete;
|
||||
|
||||
[[nodiscard]] auto TrackerInfo() const { return tracker_; }
|
||||
[[nodiscard]] Result ConnectTracker(TCPSocket* out) const;
|
||||
[[nodiscard]] auto Domain() const { return domain_; }
|
||||
[[nodiscard]] auto Timeout() const { return timeout_; }
|
||||
|
||||
[[nodiscard]] auto Rank() const { return rank_; }
|
||||
[[nodiscard]] auto World() const { return world_; }
|
||||
[[nodiscard]] bool IsDistributed() const { return World() > 1; }
|
||||
void Submit(Loop::Op op) const { loop_->Submit(op); }
|
||||
[[nodiscard]] Result Block() const { return loop_->Block(); }
|
||||
|
||||
[[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const {
|
||||
return channels_.at(rank);
|
||||
}
|
||||
[[nodiscard]] virtual bool IsFederated() const = 0;
|
||||
[[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
|
||||
|
||||
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
|
||||
};
|
||||
|
||||
class RabitComm : public Comm {
|
||||
[[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string task_id);
|
||||
[[nodiscard]] Result Shutdown();
|
||||
|
||||
public:
|
||||
// bootstrapping construction.
|
||||
RabitComm() = default;
|
||||
// ctor for testing where environment is known.
|
||||
RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t retry, std::string task_id);
|
||||
~RabitComm() noexcept(false) override;
|
||||
|
||||
[[nodiscard]] bool IsFederated() const override { return false; }
|
||||
[[nodiscard]] Result LogTracker(std::string msg) const override;
|
||||
|
||||
[[nodiscard]] Result SignalError(Result const&) override;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Communication channel between workers.
|
||||
*/
|
||||
class Channel {
|
||||
std::shared_ptr<TCPSocket> sock_{nullptr};
|
||||
Result rc_;
|
||||
Comm const& comm_;
|
||||
|
||||
public:
|
||||
explicit Channel(Comm const& comm, std::shared_ptr<TCPSocket> sock)
|
||||
: sock_{std::move(sock)}, comm_{comm} {}
|
||||
|
||||
void SendAll(std::int8_t const* ptr, std::size_t n) {
|
||||
Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast<std::int8_t*>(ptr), n, sock_.get(), 0};
|
||||
CHECK(sock_.get());
|
||||
comm_.Submit(std::move(op));
|
||||
}
|
||||
void SendAll(common::Span<std::int8_t const> data) {
|
||||
this->SendAll(data.data(), data.size_bytes());
|
||||
}
|
||||
|
||||
void RecvAll(std::int8_t* ptr, std::size_t n) {
|
||||
Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0};
|
||||
CHECK(sock_.get());
|
||||
comm_.Submit(std::move(op));
|
||||
}
|
||||
void RecvAll(common::Span<std::int8_t> data) { this->RecvAll(data.data(), data.size_bytes()); }
|
||||
|
||||
[[nodiscard]] auto Socket() const { return sock_; }
|
||||
[[nodiscard]] Result Block() { return comm_.Block(); }
|
||||
};
|
||||
|
||||
enum class Op { kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 };
|
||||
|
||||
template <typename T, typename U = std::conditional_t<std::is_const_v<T>,
|
||||
std::add_const_t<std::int8_t>, std::int8_t>>
|
||||
common::Span<U> EraseType(common::Span<T> data) {
|
||||
auto n_total_bytes = data.size_bytes();
|
||||
auto erased = common::Span{reinterpret_cast<std::add_pointer_t<U>>(data.data()), n_total_bytes};
|
||||
return erased;
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
common::Span<T> RestoreType(common::Span<U> data) {
|
||||
static_assert(std::is_same_v<std::remove_const_t<U>, std::int8_t>);
|
||||
auto n_total_bytes = data.size_bytes();
|
||||
auto restored = common::Span{reinterpret_cast<T*>(data.data()), n_total_bytes / sizeof(T)};
|
||||
return restored;
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@@ -57,9 +57,7 @@ namespace collective {
|
||||
* - federated_client_key: Client key file path. Only needed for the SSL mode.
|
||||
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
|
||||
*/
|
||||
inline void Init(Json const& config) {
|
||||
Communicator::Init(config);
|
||||
}
|
||||
inline void Init(Json const &config) { Communicator::Init(config); }
|
||||
|
||||
/*!
|
||||
* \brief Finalize the collective communicator.
|
||||
@@ -141,17 +139,89 @@ inline void Broadcast(std::string *sendrecv_data, int root) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gathers a single value all processes and distributes the result to all processes.
|
||||
*
|
||||
* @param input The single value.
|
||||
*/
|
||||
template <typename T>
|
||||
inline std::vector<T> Allgather(T const &input) {
|
||||
std::string_view str_input{reinterpret_cast<char const *>(&input), sizeof(T)};
|
||||
auto const output = Communicator::Get()->AllGather(str_input);
|
||||
CHECK_EQ(output.size() % sizeof(T), 0);
|
||||
std::vector<T> result(output.size() / sizeof(T));
|
||||
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gathers data from all processes and distributes it to all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size, and input data has been sliced into the
|
||||
* corresponding position.
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param size Size of the data in bytes.
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
inline void Allgather(void *send_receive_buffer, std::size_t size) {
|
||||
Communicator::Get()->AllGather(send_receive_buffer, size);
|
||||
template <typename T>
|
||||
inline std::vector<T> Allgather(std::vector<T> const &input) {
|
||||
if (input.empty()) {
|
||||
return input;
|
||||
}
|
||||
std::string_view str_input{reinterpret_cast<char const *>(input.data()),
|
||||
input.size() * sizeof(T)};
|
||||
auto const output = Communicator::Get()->AllGather(str_input);
|
||||
CHECK_EQ(output.size() % sizeof(T), 0);
|
||||
std::vector<T> result(output.size() / sizeof(T));
|
||||
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length data from all processes and distributes it to all processes.
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
template <typename T>
|
||||
inline std::vector<T> AllgatherV(std::vector<T> const &input) {
|
||||
std::string_view str_input{reinterpret_cast<char const *>(input.data()),
|
||||
input.size() * sizeof(T)};
|
||||
auto const output = Communicator::Get()->AllGatherV(str_input);
|
||||
CHECK_EQ(output.size() % sizeof(T), 0);
|
||||
std::vector<T> result(output.size() / sizeof(T));
|
||||
if (!output.empty()) {
|
||||
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length strings from all processes and distributes them to all processes.
|
||||
* @param input Variable-length list of variable-length strings.
|
||||
*/
|
||||
inline std::vector<std::string> AllgatherStrings(std::vector<std::string> const &input) {
|
||||
std::size_t total_size{0};
|
||||
for (auto const &s : input) {
|
||||
total_size += s.length() + 1; // +1 for null-terminators
|
||||
}
|
||||
std::string flat_string;
|
||||
flat_string.reserve(total_size);
|
||||
for (auto const &s : input) {
|
||||
flat_string.append(s);
|
||||
flat_string.push_back('\0'); // Append a null-terminator after each string
|
||||
}
|
||||
|
||||
auto const output = Communicator::Get()->AllGatherV(flat_string);
|
||||
|
||||
std::vector<std::string> result;
|
||||
std::size_t start_index = 0;
|
||||
// Iterate through the output, find each null-terminated substring.
|
||||
for (std::size_t i = 0; i < output.size(); i++) {
|
||||
if (output[i] == '\0') {
|
||||
// Construct a std::string from the char* substring
|
||||
result.emplace_back(&output[start_index]);
|
||||
// Move to the next substring
|
||||
start_index = i + 1;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/*!
|
||||
@@ -226,7 +296,7 @@ inline void Allreduce(double *send_receive_buffer, size_t count) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct AllgatherVResult {
|
||||
struct SpecialAllgatherVResult {
|
||||
std::vector<std::size_t> offsets;
|
||||
std::vector<std::size_t> sizes;
|
||||
std::vector<T> result;
|
||||
@@ -241,14 +311,10 @@ struct AllgatherVResult {
|
||||
* @param sizes Sizes of each input.
|
||||
*/
|
||||
template <typename T>
|
||||
inline AllgatherVResult<T> AllgatherV(std::vector<T> const &inputs,
|
||||
std::vector<std::size_t> const &sizes) {
|
||||
auto num_inputs = sizes.size();
|
||||
|
||||
inline SpecialAllgatherVResult<T> SpecialAllgatherV(std::vector<T> const &inputs,
|
||||
std::vector<std::size_t> const &sizes) {
|
||||
// Gather the sizes across all workers.
|
||||
std::vector<std::size_t> all_sizes(num_inputs * GetWorldSize());
|
||||
std::copy_n(sizes.cbegin(), sizes.size(), all_sizes.begin() + num_inputs * GetRank());
|
||||
collective::Allgather(all_sizes.data(), all_sizes.size() * sizeof(std::size_t));
|
||||
auto const all_sizes = Allgather(sizes);
|
||||
|
||||
// Calculate input offsets (std::exclusive_scan).
|
||||
std::vector<std::size_t> offsets(all_sizes.size());
|
||||
@@ -257,11 +323,7 @@ inline AllgatherVResult<T> AllgatherV(std::vector<T> const &inputs,
|
||||
}
|
||||
|
||||
// Gather all the inputs.
|
||||
auto total_input_size = offsets.back() + all_sizes.back();
|
||||
std::vector<T> all_inputs(total_input_size);
|
||||
std::copy_n(inputs.cbegin(), inputs.size(), all_inputs.begin() + offsets[num_inputs * GetRank()]);
|
||||
// We cannot use allgather here, since each worker might have a different size.
|
||||
Allreduce<Operation::kMax>(all_inputs.data(), all_inputs.size());
|
||||
auto const all_inputs = AllgatherV(inputs);
|
||||
|
||||
return {offsets, all_sizes, all_inputs};
|
||||
}
|
||||
|
||||
@@ -11,9 +11,7 @@
|
||||
#include "../../plugin/federated/federated_communicator.h"
|
||||
#endif
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
namespace xgboost::collective {
|
||||
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
|
||||
thread_local CommunicatorType Communicator::type_{};
|
||||
|
||||
@@ -57,6 +55,4 @@ void Communicator::Finalize() {
|
||||
communicator_.reset(new NoOpCommunicator());
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -125,13 +125,17 @@ class Communicator {
|
||||
/**
|
||||
* @brief Gathers data from all processes and distributes it to all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size, and input data has been sliced into the
|
||||
* corresponding position.
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param size Size of the data in bytes.
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
virtual void AllGather(void *send_receive_buffer, std::size_t size) = 0;
|
||||
virtual std::string AllGather(std::string_view input) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length data from all processes and distributes it to all processes.
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
virtual std::string AllGatherV(std::string_view input) = 0;
|
||||
|
||||
/**
|
||||
* @brief Combines values from all processes and distributes the result back to all processes.
|
||||
|
||||
@@ -40,12 +40,10 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
host_buffer_.resize(send_size * world_size_);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + rank_ * send_size, send_buffer, send_size,
|
||||
cudaMemcpyDefault));
|
||||
Allgather(host_buffer_.data(), host_buffer_.size());
|
||||
dh::safe_cuda(
|
||||
cudaMemcpy(receive_buffer, host_buffer_.data(), host_buffer_.size(), cudaMemcpyDefault));
|
||||
host_buffer_.resize(send_size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_buffer, send_size, cudaMemcpyDefault));
|
||||
auto const output = Allgather(host_buffer_);
|
||||
dh::safe_cuda(cudaMemcpy(receive_buffer, output.data(), output.size(), cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
|
||||
@@ -60,11 +60,16 @@ class InMemoryCommunicator : public Communicator {
|
||||
bool IsDistributed() const override { return true; }
|
||||
bool IsFederated() const override { return false; }
|
||||
|
||||
void AllGather(void* in_out, std::size_t size) override {
|
||||
std::string AllGather(std::string_view input) override {
|
||||
std::string output;
|
||||
handler_.Allgather(static_cast<const char*>(in_out), size, &output, sequence_number_++,
|
||||
GetRank());
|
||||
output.copy(static_cast<char*>(in_out), size);
|
||||
handler_.Allgather(input.data(), input.size(), &output, sequence_number_++, GetRank());
|
||||
return output;
|
||||
}
|
||||
|
||||
std::string AllGatherV(std::string_view input) override {
|
||||
std::string output;
|
||||
handler_.AllgatherV(input.data(), input.size(), &output, sequence_number_++, GetRank());
|
||||
return output;
|
||||
}
|
||||
|
||||
void AllReduce(void* in_out, std::size_t size, DataType data_type, Operation operation) override {
|
||||
|
||||
@@ -16,23 +16,49 @@ class AllgatherFunctor {
|
||||
public:
|
||||
std::string const name{"Allgather"};
|
||||
|
||||
AllgatherFunctor(int world_size, int rank) : world_size_{world_size}, rank_{rank} {}
|
||||
AllgatherFunctor(std::size_t world_size, std::size_t rank)
|
||||
: world_size_{world_size}, rank_{rank} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
if (buffer->empty()) {
|
||||
// Copy the input if this is the first request.
|
||||
buffer->assign(input, bytes);
|
||||
} else {
|
||||
// Splice the input into the common buffer.
|
||||
auto const per_rank = bytes / world_size_;
|
||||
auto const index = rank_ * per_rank;
|
||||
buffer->replace(index, per_rank, input + index, per_rank);
|
||||
// Resize the buffer if this is the first request.
|
||||
buffer->resize(bytes * world_size_);
|
||||
}
|
||||
|
||||
// Splice the input into the common buffer.
|
||||
buffer->replace(rank_ * bytes, bytes, input, bytes);
|
||||
}
|
||||
|
||||
private:
|
||||
std::size_t world_size_;
|
||||
std::size_t rank_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Functor for variable-length allgather.
|
||||
*/
|
||||
class AllgatherVFunctor {
|
||||
public:
|
||||
std::string const name{"AllgatherV"};
|
||||
|
||||
AllgatherVFunctor(std::size_t world_size, std::size_t rank,
|
||||
std::map<std::size_t, std::string_view>* data)
|
||||
: world_size_{world_size}, rank_{rank}, data_{data} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
data_->emplace(rank_, std::string_view{input, bytes});
|
||||
if (data_->size() == world_size_) {
|
||||
for (auto const& kv : *data_) {
|
||||
buffer->append(kv.second);
|
||||
}
|
||||
data_->clear();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int world_size_;
|
||||
int rank_;
|
||||
std::size_t world_size_;
|
||||
std::size_t rank_;
|
||||
std::map<std::size_t, std::string_view>* data_;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -154,7 +180,7 @@ class BroadcastFunctor {
|
||||
public:
|
||||
std::string const name{"Broadcast"};
|
||||
|
||||
BroadcastFunctor(int rank, int root) : rank_{rank}, root_{root} {}
|
||||
BroadcastFunctor(std::size_t rank, std::size_t root) : rank_{rank}, root_{root} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
if (rank_ == root_) {
|
||||
@@ -164,11 +190,11 @@ class BroadcastFunctor {
|
||||
}
|
||||
|
||||
private:
|
||||
int rank_;
|
||||
int root_;
|
||||
std::size_t rank_;
|
||||
std::size_t root_;
|
||||
};
|
||||
|
||||
void InMemoryHandler::Init(int world_size, int) {
|
||||
void InMemoryHandler::Init(std::size_t world_size, std::size_t) {
|
||||
CHECK(world_size_ < world_size) << "In memory handler already initialized.";
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
@@ -178,7 +204,7 @@ void InMemoryHandler::Init(int world_size, int) {
|
||||
cv_.notify_all();
|
||||
}
|
||||
|
||||
void InMemoryHandler::Shutdown(uint64_t sequence_number, int) {
|
||||
void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) {
|
||||
CHECK(world_size_ > 0) << "In memory handler already shutdown.";
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
@@ -194,24 +220,30 @@ void InMemoryHandler::Shutdown(uint64_t sequence_number, int) {
|
||||
}
|
||||
|
||||
void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank) {
|
||||
std::size_t sequence_number, std::size_t rank) {
|
||||
Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank});
|
||||
}
|
||||
|
||||
void InMemoryHandler::AllgatherV(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank) {
|
||||
Handle(input, bytes, output, sequence_number, rank, AllgatherVFunctor{world_size_, rank, &aux_});
|
||||
}
|
||||
|
||||
void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank, DataType data_type,
|
||||
std::size_t sequence_number, std::size_t rank, DataType data_type,
|
||||
Operation op) {
|
||||
Handle(input, bytes, output, sequence_number, rank, AllreduceFunctor{data_type, op});
|
||||
}
|
||||
|
||||
void InMemoryHandler::Broadcast(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank, int root) {
|
||||
std::size_t sequence_number, std::size_t rank, std::size_t root) {
|
||||
Handle(input, bytes, output, sequence_number, rank, BroadcastFunctor{rank, root});
|
||||
}
|
||||
|
||||
template <class HandlerFunctor>
|
||||
void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank, HandlerFunctor const& functor) {
|
||||
std::size_t sequence_number, std::size_t rank,
|
||||
HandlerFunctor const& functor) {
|
||||
// Pass through if there is only 1 client.
|
||||
if (world_size_ == 1) {
|
||||
if (input != output->data()) {
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
#include <condition_variable>
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "communicator.h"
|
||||
@@ -31,7 +32,7 @@ class InMemoryHandler {
|
||||
*
|
||||
* This is used when the handler only needs to be initialized once with a known world size.
|
||||
*/
|
||||
explicit InMemoryHandler(int worldSize) : world_size_{worldSize} {}
|
||||
explicit InMemoryHandler(std::size_t worldSize) : world_size_{worldSize} {}
|
||||
|
||||
/**
|
||||
* @brief Initialize the handler with the world size and rank.
|
||||
@@ -41,7 +42,7 @@ class InMemoryHandler {
|
||||
* This is used when multiple objects/threads are accessing the same handler and need to
|
||||
* initialize it collectively.
|
||||
*/
|
||||
void Init(int world_size, int rank);
|
||||
void Init(std::size_t world_size, std::size_t rank);
|
||||
|
||||
/**
|
||||
* @brief Shut down the handler.
|
||||
@@ -51,7 +52,7 @@ class InMemoryHandler {
|
||||
* This is used when multiple objects/threads are accessing the same handler and need to
|
||||
* shut it down collectively.
|
||||
*/
|
||||
void Shutdown(uint64_t sequence_number, int rank);
|
||||
void Shutdown(uint64_t sequence_number, std::size_t rank);
|
||||
|
||||
/**
|
||||
* @brief Perform allgather.
|
||||
@@ -62,7 +63,18 @@ class InMemoryHandler {
|
||||
* @param rank Index of the worker.
|
||||
*/
|
||||
void Allgather(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank);
|
||||
std::size_t sequence_number, std::size_t rank);
|
||||
|
||||
/**
|
||||
* @brief Perform variable-length allgather.
|
||||
* @param input The input buffer.
|
||||
* @param bytes Number of bytes in the input buffer.
|
||||
* @param output The output buffer.
|
||||
* @param sequence_number Call sequence number.
|
||||
* @param rank Index of the worker.
|
||||
*/
|
||||
void AllgatherV(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank);
|
||||
|
||||
/**
|
||||
* @brief Perform allreduce.
|
||||
@@ -75,7 +87,7 @@ class InMemoryHandler {
|
||||
* @param op The reduce operation.
|
||||
*/
|
||||
void Allreduce(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank, DataType data_type, Operation op);
|
||||
std::size_t sequence_number, std::size_t rank, DataType data_type, Operation op);
|
||||
|
||||
/**
|
||||
* @brief Perform broadcast.
|
||||
@@ -87,7 +99,7 @@ class InMemoryHandler {
|
||||
* @param root Index of the worker to broadcast from.
|
||||
*/
|
||||
void Broadcast(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank, int root);
|
||||
std::size_t sequence_number, std::size_t rank, std::size_t root);
|
||||
|
||||
private:
|
||||
/**
|
||||
@@ -102,15 +114,16 @@ class InMemoryHandler {
|
||||
*/
|
||||
template <class HandlerFunctor>
|
||||
void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number,
|
||||
int rank, HandlerFunctor const& functor);
|
||||
std::size_t rank, HandlerFunctor const& functor);
|
||||
|
||||
int world_size_{}; /// Number of workers.
|
||||
int received_{}; /// Number of calls received with the current sequence.
|
||||
int sent_{}; /// Number of calls completed with the current sequence.
|
||||
std::string buffer_{}; /// A shared common buffer.
|
||||
uint64_t sequence_number_{}; /// Call sequence number.
|
||||
mutable std::mutex mutex_; /// Lock.
|
||||
mutable std::condition_variable cv_; /// Conditional variable to wait on.
|
||||
std::size_t world_size_{}; /// Number of workers.
|
||||
std::size_t received_{}; /// Number of calls received with the current sequence.
|
||||
std::size_t sent_{}; /// Number of calls completed with the current sequence.
|
||||
std::string buffer_{}; /// A shared common buffer.
|
||||
std::map<std::size_t, std::string_view> aux_{}; /// A shared auxiliary map.
|
||||
uint64_t sequence_number_{}; /// Call sequence number.
|
||||
mutable std::mutex mutex_; /// Lock.
|
||||
mutable std::condition_variable cv_; /// Conditional variable to wait on.
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
|
||||
167
src/collective/loop.cc
Normal file
167
src/collective/loop.cc
Normal file
@@ -0,0 +1,167 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include "loop.h"
|
||||
|
||||
#include <queue> // for queue
|
||||
|
||||
#include "rabit/internal/socket.h" // for PollHelper
|
||||
#include "xgboost/collective/socket.h" // for FailWithCode
|
||||
#include "xgboost/logging.h" // for CHECK
|
||||
|
||||
namespace xgboost::collective {
|
||||
Result Loop::EmptyQueue() {
|
||||
timer_.Start(__func__);
|
||||
auto error = [this] {
|
||||
this->stop_ = true;
|
||||
timer_.Stop(__func__);
|
||||
};
|
||||
|
||||
while (!queue_.empty() && !stop_) {
|
||||
std::queue<Op> qcopy;
|
||||
rabit::utils::PollHelper poll;
|
||||
|
||||
// watch all ops
|
||||
while (!queue_.empty()) {
|
||||
auto op = queue_.front();
|
||||
queue_.pop();
|
||||
|
||||
switch (op.code) {
|
||||
case Op::kRead: {
|
||||
poll.WatchRead(*op.sock);
|
||||
break;
|
||||
}
|
||||
case Op::kWrite: {
|
||||
poll.WatchWrite(*op.sock);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
error();
|
||||
return Fail("Invalid socket operation.");
|
||||
}
|
||||
}
|
||||
qcopy.push(op);
|
||||
}
|
||||
|
||||
// poll, work on fds that are ready.
|
||||
timer_.Start("poll");
|
||||
auto rc = poll.Poll(timeout_);
|
||||
timer_.Stop("poll");
|
||||
if (!rc.OK()) {
|
||||
error();
|
||||
return rc;
|
||||
}
|
||||
// we wonldn't be here if the queue is empty.
|
||||
CHECK(!qcopy.empty());
|
||||
|
||||
while (!qcopy.empty() && !stop_) {
|
||||
auto op = qcopy.front();
|
||||
qcopy.pop();
|
||||
|
||||
std::int32_t n_bytes_done{0};
|
||||
CHECK(op.sock->NonBlocking());
|
||||
|
||||
switch (op.code) {
|
||||
case Op::kRead: {
|
||||
if (poll.CheckRead(*op.sock)) {
|
||||
n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Op::kWrite: {
|
||||
if (poll.CheckWrite(*op.sock)) {
|
||||
n_bytes_done = op.sock->Send(op.ptr + op.off, op.n - op.off);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
error();
|
||||
return Fail("Invalid socket operation.");
|
||||
}
|
||||
}
|
||||
|
||||
if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) {
|
||||
stop_ = true;
|
||||
auto rc = system::FailWithCode("Invalid socket output.");
|
||||
error();
|
||||
return rc;
|
||||
}
|
||||
op.off += n_bytes_done;
|
||||
CHECK_LE(op.off, op.n);
|
||||
|
||||
if (op.off != op.n) {
|
||||
// not yet finished, push back to queue for next round.
|
||||
queue_.push(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
timer_.Stop(__func__);
|
||||
return Success();
|
||||
}
|
||||
|
||||
void Loop::Process() {
|
||||
// consumer
|
||||
while (true) {
|
||||
std::unique_lock lock{mu_};
|
||||
cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; });
|
||||
if (stop_) {
|
||||
break;
|
||||
}
|
||||
CHECK(!mu_.try_lock());
|
||||
|
||||
this->rc_ = this->EmptyQueue();
|
||||
if (!rc_.OK()) {
|
||||
stop_ = true;
|
||||
cv_.notify_one();
|
||||
break;
|
||||
}
|
||||
|
||||
CHECK(queue_.empty());
|
||||
CHECK(!mu_.try_lock());
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
if (rc_.OK()) {
|
||||
CHECK(queue_.empty());
|
||||
}
|
||||
}
|
||||
|
||||
Result Loop::Stop() {
|
||||
std::unique_lock lock{mu_};
|
||||
stop_ = true;
|
||||
lock.unlock();
|
||||
|
||||
CHECK_EQ(this->Block().OK(), this->rc_.OK());
|
||||
|
||||
if (curr_exce_) {
|
||||
std::rethrow_exception(curr_exce_);
|
||||
}
|
||||
|
||||
return Success();
|
||||
}
|
||||
|
||||
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
|
||||
timer_.Init(__func__);
|
||||
worker_ = std::thread{[this] {
|
||||
try {
|
||||
this->Process();
|
||||
} catch (std::exception const& e) {
|
||||
std::lock_guard<std::mutex> guard{mu_};
|
||||
if (!curr_exce_) {
|
||||
curr_exce_ = std::current_exception();
|
||||
rc_ = Fail("Exception was thrown");
|
||||
}
|
||||
stop_ = true;
|
||||
cv_.notify_all();
|
||||
} catch (...) {
|
||||
std::lock_guard<std::mutex> guard{mu_};
|
||||
if (!curr_exce_) {
|
||||
curr_exce_ = std::current_exception();
|
||||
rc_ = Fail("Exception was thrown");
|
||||
}
|
||||
stop_ = true;
|
||||
cv_.notify_all();
|
||||
}
|
||||
}};
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
83
src/collective/loop.h
Normal file
83
src/collective/loop.h
Normal file
@@ -0,0 +1,83 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <chrono> // for seconds
|
||||
#include <condition_variable> // for condition_variable
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t, int32_t
|
||||
#include <exception> // for exception_ptr
|
||||
#include <mutex> // for unique_lock, mutex
|
||||
#include <queue> // for queue
|
||||
#include <thread> // for thread
|
||||
#include <utility> // for move
|
||||
|
||||
#include "../common/timer.h" // for Monitor
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||
|
||||
namespace xgboost::collective {
|
||||
class Loop {
|
||||
public:
|
||||
struct Op {
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1 } code;
|
||||
std::int32_t rank{-1};
|
||||
std::int8_t* ptr{nullptr};
|
||||
std::size_t n{0};
|
||||
TCPSocket* sock{nullptr};
|
||||
std::size_t off{0};
|
||||
|
||||
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
|
||||
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
|
||||
Op(Op const&) = default;
|
||||
Op& operator=(Op const&) = default;
|
||||
Op(Op&&) = default;
|
||||
Op& operator=(Op&&) = default;
|
||||
};
|
||||
|
||||
private:
|
||||
std::thread worker_;
|
||||
std::condition_variable cv_;
|
||||
std::mutex mu_;
|
||||
std::queue<Op> queue_;
|
||||
std::chrono::seconds timeout_;
|
||||
Result rc_;
|
||||
bool stop_{false};
|
||||
std::exception_ptr curr_exce_{nullptr};
|
||||
common::Monitor timer_;
|
||||
|
||||
Result EmptyQueue();
|
||||
void Process();
|
||||
|
||||
public:
|
||||
Result Stop();
|
||||
|
||||
void Submit(Op op) {
|
||||
// producer
|
||||
std::unique_lock lock{mu_};
|
||||
queue_.push(op);
|
||||
lock.unlock();
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Block() {
|
||||
{
|
||||
std::unique_lock lock{mu_};
|
||||
cv_.notify_all();
|
||||
}
|
||||
std::unique_lock lock{mu_};
|
||||
cv_.wait(lock, [this] { return this->queue_.empty() || stop_; });
|
||||
return std::move(rc_);
|
||||
}
|
||||
|
||||
explicit Loop(std::chrono::seconds timeout);
|
||||
|
||||
~Loop() noexcept(false) {
|
||||
this->Stop();
|
||||
|
||||
if (worker_.joinable()) {
|
||||
worker_.join();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@@ -17,10 +17,11 @@ class NoOpCommunicator : public Communicator {
|
||||
NoOpCommunicator() : Communicator(1, 0) {}
|
||||
bool IsDistributed() const override { return false; }
|
||||
bool IsFederated() const override { return false; }
|
||||
void AllGather(void *, std::size_t) override {}
|
||||
std::string AllGather(std::string_view) override { return {}; }
|
||||
std::string AllGatherV(std::string_view) override { return {}; }
|
||||
void AllReduce(void *, std::size_t, DataType, Operation) override {}
|
||||
void Broadcast(void *, std::size_t, int) override {}
|
||||
std::string GetProcessorName() override { return ""; }
|
||||
std::string GetProcessorName() override { return {}; }
|
||||
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
|
||||
|
||||
protected:
|
||||
|
||||
214
src/collective/protocol.h
Normal file
214
src/collective/protocol.h
Normal file
@@ -0,0 +1,214 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstdint> // for int32_t
|
||||
#include <string> // for string
|
||||
#include <utility> // for move
|
||||
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective::proto {
|
||||
struct PeerInfo {
|
||||
std::string host;
|
||||
std::int32_t port{-1};
|
||||
std::int32_t rank{-1};
|
||||
|
||||
PeerInfo() = default;
|
||||
PeerInfo(std::string host, std::int32_t port, std::int32_t rank)
|
||||
: host{std::move(host)}, port{port}, rank{rank} {}
|
||||
|
||||
explicit PeerInfo(Json const& peer)
|
||||
: host{get<String>(peer["host"])},
|
||||
port{static_cast<std::int32_t>(get<Integer const>(peer["port"]))},
|
||||
rank{static_cast<std::int32_t>(get<Integer const>(peer["rank"]))} {}
|
||||
|
||||
[[nodiscard]] Json ToJson() const {
|
||||
Json info{Object{}};
|
||||
info["rank"] = rank;
|
||||
info["host"] = String{host};
|
||||
info["port"] = Integer{port};
|
||||
return info;
|
||||
}
|
||||
|
||||
[[nodiscard]] auto HostPort() const { return host + ":" + std::to_string(this->port); }
|
||||
};
|
||||
|
||||
struct Magic {
|
||||
static constexpr std::int32_t kMagic = 0xff99;
|
||||
|
||||
[[nodiscard]] Result Verify(xgboost::collective::TCPSocket* p_sock) {
|
||||
std::int32_t magic{kMagic};
|
||||
auto n_bytes = p_sock->SendAll(&magic, sizeof(magic));
|
||||
if (n_bytes != sizeof(magic)) {
|
||||
return Fail("Failed to verify.");
|
||||
}
|
||||
|
||||
magic = 0;
|
||||
n_bytes = p_sock->RecvAll(&magic, sizeof(magic));
|
||||
if (n_bytes != sizeof(magic)) {
|
||||
return Fail("Failed to verify.");
|
||||
}
|
||||
if (magic != kMagic) {
|
||||
return xgboost::collective::Fail("Invalid verification number.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
};
|
||||
|
||||
enum class CMD : std::int32_t {
|
||||
kInvalid = 0,
|
||||
kStart = 1,
|
||||
kShutdown = 2,
|
||||
kError = 3,
|
||||
kPrint = 4,
|
||||
};
|
||||
|
||||
struct Connect {
|
||||
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::int32_t world, std::int32_t rank,
|
||||
std::string task_id) const {
|
||||
Json jinit{Object{}};
|
||||
jinit["world_size"] = Integer{world};
|
||||
jinit["rank"] = Integer{rank};
|
||||
jinit["task_id"] = String{task_id};
|
||||
std::string msg;
|
||||
Json::Dump(jinit, &msg);
|
||||
auto n_bytes = tracker->Send(msg);
|
||||
if (n_bytes != msg.size()) {
|
||||
return Fail("Failed to send init command from worker.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
[[nodiscard]] Result TrackerRecv(TCPSocket* sock, std::int32_t* world, std::int32_t* rank,
|
||||
std::string* task_id) const {
|
||||
std::string init;
|
||||
sock->Recv(&init);
|
||||
auto jinit = Json::Load(StringView{init});
|
||||
*world = get<Integer const>(jinit["world_size"]);
|
||||
*rank = get<Integer const>(jinit["rank"]);
|
||||
*task_id = get<String const>(jinit["task_id"]);
|
||||
return Success();
|
||||
}
|
||||
};
|
||||
|
||||
class Start {
|
||||
private:
|
||||
[[nodiscard]] Result TrackerSend(std::int32_t world, TCPSocket* worker) const {
|
||||
Json jcmd{Object{}};
|
||||
jcmd["world_size"] = Integer{world};
|
||||
auto scmd = Json::Dump(jcmd);
|
||||
auto n_bytes = worker->Send(scmd);
|
||||
if (n_bytes != scmd.size()) {
|
||||
return Fail("Failed to send init command from tracker.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
public:
|
||||
[[nodiscard]] Result WorkerSend(std::int32_t lport, TCPSocket* tracker,
|
||||
std::int32_t eport) const {
|
||||
Json jcmd{Object{}};
|
||||
jcmd["cmd"] = Integer{static_cast<std::int32_t>(CMD::kStart)};
|
||||
jcmd["port"] = Integer{lport};
|
||||
jcmd["error_port"] = Integer{eport};
|
||||
auto scmd = Json::Dump(jcmd);
|
||||
auto n_bytes = tracker->Send(scmd);
|
||||
if (n_bytes != scmd.size()) {
|
||||
return Fail("Failed to send init command from worker.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
[[nodiscard]] Result WorkerRecv(TCPSocket* tracker, std::int32_t* p_world) const {
|
||||
std::string scmd;
|
||||
auto n_bytes = tracker->Recv(&scmd);
|
||||
if (n_bytes <= 0) {
|
||||
return Fail("Failed to recv init command from tracker.");
|
||||
}
|
||||
auto jcmd = Json::Load(scmd);
|
||||
auto world = get<Integer const>(jcmd["world_size"]);
|
||||
if (world <= 0) {
|
||||
return Fail("Invalid world size.");
|
||||
}
|
||||
*p_world = world;
|
||||
return Success();
|
||||
}
|
||||
[[nodiscard]] Result TrackerHandle(Json jcmd, std::int32_t* recv_world, std::int32_t world,
|
||||
std::int32_t* p_port, TCPSocket* p_sock,
|
||||
std::int32_t* eport) const {
|
||||
*p_port = get<Integer const>(jcmd["port"]);
|
||||
if (*p_port <= 0) {
|
||||
return Fail("Invalid port.");
|
||||
}
|
||||
if (*recv_world != -1) {
|
||||
return Fail("Invalid initialization sequence.");
|
||||
}
|
||||
*recv_world = world;
|
||||
*eport = get<Integer const>(jcmd["error_port"]);
|
||||
return TrackerSend(world, p_sock);
|
||||
}
|
||||
};
|
||||
|
||||
struct Print {
|
||||
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::string msg) const {
|
||||
Json jcmd{Object{}};
|
||||
jcmd["cmd"] = Integer{static_cast<std::int32_t>(CMD::kPrint)};
|
||||
jcmd["msg"] = String{std::move(msg)};
|
||||
auto scmd = Json::Dump(jcmd);
|
||||
auto n_bytes = tracker->Send(scmd);
|
||||
if (n_bytes != scmd.size()) {
|
||||
return Fail("Failed to send print command from worker.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
[[nodiscard]] Result TrackerHandle(Json jcmd, std::string* p_msg) const {
|
||||
if (!IsA<String>(jcmd["msg"])) {
|
||||
return Fail("Invalid print command.");
|
||||
}
|
||||
auto msg = get<String const>(jcmd["msg"]);
|
||||
*p_msg = msg;
|
||||
return Success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ErrorCMD {
|
||||
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, Result const& res) const {
|
||||
auto msg = res.Report();
|
||||
auto code = res.Code().value();
|
||||
Json jcmd{Object{}};
|
||||
jcmd["msg"] = String{std::move(msg)};
|
||||
jcmd["code"] = Integer{code};
|
||||
jcmd["cmd"] = Integer{static_cast<std::int32_t>(CMD::kError)};
|
||||
auto scmd = Json::Dump(jcmd);
|
||||
auto n_bytes = tracker->Send(scmd);
|
||||
if (n_bytes != scmd.size()) {
|
||||
return Fail("Failed to send error command from worker.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
[[nodiscard]] Result TrackerHandle(Json jcmd, std::string* p_msg, int* p_code) const {
|
||||
if (!IsA<String>(jcmd["msg"]) || !IsA<Integer>(jcmd["code"])) {
|
||||
return Fail("Invalid error command.");
|
||||
}
|
||||
auto msg = get<String const>(jcmd["msg"]);
|
||||
auto code = get<Integer const>(jcmd["code"]);
|
||||
*p_msg = msg;
|
||||
*p_code = code;
|
||||
return Success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ShutdownCMD {
|
||||
[[nodiscard]] Result Send(TCPSocket* peer) const {
|
||||
Json jcmd{Object{}};
|
||||
jcmd["cmd"] = Integer{static_cast<std::int32_t>(proto::CMD::kShutdown)};
|
||||
auto scmd = Json::Dump(jcmd);
|
||||
auto n_bytes = peer->Send(scmd);
|
||||
if (n_bytes != scmd.size()) {
|
||||
return Fail("Failed to send shutdown command from worker.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective::proto
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator-inl.h"
|
||||
#include "communicator.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
@@ -55,10 +56,29 @@ class RabitCommunicator : public Communicator {
|
||||
|
||||
bool IsFederated() const override { return false; }
|
||||
|
||||
void AllGather(void *send_receive_buffer, std::size_t size) override {
|
||||
auto const per_rank = size / GetWorldSize();
|
||||
std::string AllGather(std::string_view input) override {
|
||||
auto const per_rank = input.size();
|
||||
auto const total_size = per_rank * GetWorldSize();
|
||||
auto const index = per_rank * GetRank();
|
||||
rabit::Allgather(static_cast<char *>(send_receive_buffer), size, index, per_rank, per_rank);
|
||||
std::string result(total_size, '\0');
|
||||
result.replace(index, per_rank, input);
|
||||
rabit::Allgather(result.data(), total_size, index, per_rank, per_rank);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string AllGatherV(std::string_view input) override {
|
||||
auto const size_node_slice = input.size();
|
||||
auto const all_sizes = collective::Allgather(size_node_slice);
|
||||
auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul);
|
||||
auto const begin_index =
|
||||
std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul);
|
||||
auto const size_prev_slice =
|
||||
GetRank() == 0 ? all_sizes[GetWorldSize() - 1] : all_sizes[GetRank() - 1];
|
||||
|
||||
std::string result(total_size, '\0');
|
||||
result.replace(begin_index, size_node_slice, input);
|
||||
rabit::Allgather(result.data(), total_size, begin_index, size_node_slice, size_prev_slice);
|
||||
return result;
|
||||
}
|
||||
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
*/
|
||||
#include "xgboost/collective/socket.h"
|
||||
|
||||
#include <array> // for array
|
||||
#include <cstddef> // std::size_t
|
||||
#include <cstdint> // std::int32_t
|
||||
#include <cstring> // std::memcpy, std::memset
|
||||
@@ -92,13 +93,18 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
|
||||
|
||||
conn = TCPSocket::Create(addr.Domain());
|
||||
CHECK_EQ(static_cast<std::int32_t>(conn.Domain()), static_cast<std::int32_t>(addr.Domain()));
|
||||
conn.SetNonBlock(true);
|
||||
auto non_blocking = conn.NonBlocking();
|
||||
auto rc = conn.NonBlocking(true);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to set socket option.", std::move(rc));
|
||||
}
|
||||
|
||||
Result last_error;
|
||||
auto log_failure = [&host, &last_error](Result err, char const *file, std::int32_t line) {
|
||||
auto log_failure = [&host, &last_error, port](Result err, char const *file, std::int32_t line) {
|
||||
last_error = std::move(err);
|
||||
LOG(WARNING) << std::filesystem::path{file}.filename().string() << "(" << line
|
||||
<< "): Failed to connect to:" << host << " Error:" << last_error.Report();
|
||||
<< "): Failed to connect to:" << host << ":" << port
|
||||
<< " Error:" << last_error.Report();
|
||||
};
|
||||
|
||||
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
|
||||
@@ -112,39 +118,42 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
|
||||
}
|
||||
|
||||
auto rc = connect(conn.Handle(), addr_handle, addr_len);
|
||||
if (rc != 0) {
|
||||
auto errcode = system::LastError();
|
||||
if (!system::ErrorWouldBlock(errcode)) {
|
||||
log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}),
|
||||
__FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
|
||||
rabit::utils::PollHelper poll;
|
||||
poll.WatchWrite(conn);
|
||||
auto result = poll.Poll(timeout);
|
||||
if (!result.OK()) {
|
||||
log_failure(std::move(result), __FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
if (!poll.CheckWrite(conn)) {
|
||||
log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}),
|
||||
__FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
result = conn.GetSockError();
|
||||
if (!result.OK()) {
|
||||
log_failure(std::move(result), __FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
|
||||
conn.SetNonBlock(false);
|
||||
return Success();
|
||||
|
||||
} else {
|
||||
conn.SetNonBlock(false);
|
||||
return Success();
|
||||
if (rc == 0) {
|
||||
return conn.NonBlocking(non_blocking);
|
||||
}
|
||||
|
||||
auto errcode = system::LastError();
|
||||
if (!system::ErrorWouldBlock(errcode)) {
|
||||
log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}),
|
||||
__FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
|
||||
rabit::utils::PollHelper poll;
|
||||
poll.WatchWrite(conn);
|
||||
auto result = poll.Poll(timeout);
|
||||
if (!result.OK()) {
|
||||
// poll would fail if there's a socket error, we log the root cause instead of the
|
||||
// poll failure.
|
||||
auto sockerr = conn.GetSockError();
|
||||
if (!sockerr.OK()) {
|
||||
result = std::move(sockerr);
|
||||
}
|
||||
log_failure(std::move(result), __FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
if (!poll.CheckWrite(conn)) {
|
||||
log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}), __FILE__,
|
||||
__LINE__);
|
||||
continue;
|
||||
}
|
||||
result = conn.GetSockError();
|
||||
if (!result.OK()) {
|
||||
log_failure(std::move(result), __FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
|
||||
return conn.NonBlocking(non_blocking);
|
||||
}
|
||||
|
||||
std::stringstream ss;
|
||||
@@ -152,4 +161,13 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
|
||||
conn.Close();
|
||||
return Fail(ss.str(), std::move(last_error));
|
||||
}
|
||||
|
||||
[[nodiscard]] Result GetHostName(std::string *p_out) {
|
||||
std::array<char, HOST_NAME_MAX> buf;
|
||||
if (gethostname(&buf[0], HOST_NAME_MAX) != 0) {
|
||||
return system::FailWithCode("Failed to get host name.");
|
||||
}
|
||||
*p_out = buf.data();
|
||||
return Success();
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
296
src/collective/tracker.cc
Normal file
296
src/collective/tracker.cc
Normal file
@@ -0,0 +1,296 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#if defined(__unix__) || defined(__APPLE__)
|
||||
#include <netdb.h> // gethostbyname
|
||||
#include <sys/socket.h> // socket, AF_INET6, AF_INET, connect, getsockname
|
||||
#endif // defined(__unix__) || defined(__APPLE__)
|
||||
|
||||
#if !defined(NOMINMAX) && defined(_WIN32)
|
||||
#define NOMINMAX
|
||||
#endif // !defined(NOMINMAX)
|
||||
|
||||
#if defined(_WIN32)
|
||||
#include <winsock2.h>
|
||||
#include <ws2tcpip.h>
|
||||
#endif // defined(_WIN32)
|
||||
|
||||
#include <algorithm> // for sort
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int32_t
|
||||
#include <string> // for string
|
||||
#include <utility> // for move, forward
|
||||
|
||||
#include "../common/json_utils.h"
|
||||
#include "comm.h"
|
||||
#include "protocol.h" // for kMagic, PeerInfo
|
||||
#include "tracker.h"
|
||||
#include "xgboost/collective/result.h" // for Result, Fail, Success
|
||||
#include "xgboost/collective/socket.h" // for GetHostName, FailWithCode, MakeSockAddress, ...
|
||||
#include "xgboost/json.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
Tracker::Tracker(Json const& config)
|
||||
: n_workers_{static_cast<std::int32_t>(
|
||||
RequiredArg<Integer const>(config, "n_workers", __func__))},
|
||||
port_{static_cast<std::int32_t>(OptionalArg<Integer const>(config, "port", Integer::Int{0}))},
|
||||
timeout_{std::chrono::seconds{OptionalArg<Integer const>(
|
||||
config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
|
||||
|
||||
RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr)
|
||||
: sock_{std::move(sock)} {
|
||||
auto host = addr.Addr();
|
||||
|
||||
std::int32_t rank{0};
|
||||
rc_ = Success()
|
||||
<< [&] { return proto::Magic{}.Verify(&sock_); }
|
||||
<< [&] { return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_); };
|
||||
if (!rc_.OK()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::string cmd;
|
||||
sock_.Recv(&cmd);
|
||||
auto jcmd = Json::Load(StringView{cmd});
|
||||
cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
|
||||
std::int32_t port{0};
|
||||
if (cmd_ == proto::CMD::kStart) {
|
||||
proto::Start start;
|
||||
rc_ = start.TrackerHandle(jcmd, &world_, world, &port, &sock_, &eport_);
|
||||
} else if (cmd_ == proto::CMD::kPrint) {
|
||||
proto::Print print;
|
||||
rc_ = print.TrackerHandle(jcmd, &msg_);
|
||||
} else if (cmd_ == proto::CMD::kError) {
|
||||
proto::ErrorCMD error;
|
||||
rc_ = error.TrackerHandle(jcmd, &msg_, &code_);
|
||||
}
|
||||
if (!rc_.OK()) {
|
||||
return;
|
||||
}
|
||||
|
||||
info_ = proto::PeerInfo{host, port, rank};
|
||||
}
|
||||
|
||||
RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
|
||||
std::string self;
|
||||
auto rc = collective::GetHostAddress(&self);
|
||||
auto host = OptionalArg<String>(config, "host", self);
|
||||
|
||||
listener_ = TCPSocket::Create(SockDomain::kV4);
|
||||
rc = listener_.Bind(host, &this->port_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
listener_.Listen();
|
||||
}
|
||||
|
||||
Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
auto& workers = *p_workers;
|
||||
|
||||
std::sort(workers.begin(), workers.end(), WorkerCmp{});
|
||||
|
||||
std::vector<std::thread> bootstrap_threads;
|
||||
for (std::int32_t r = 0; r < n_workers_; ++r) {
|
||||
auto& worker = workers[r];
|
||||
auto next = BootstrapNext(r, n_workers_);
|
||||
auto const& next_w = workers[next];
|
||||
bootstrap_threads.emplace_back([next, &worker, &next_w] {
|
||||
auto jnext = proto::PeerInfo{next_w.Host(), next_w.Port(), next}.ToJson();
|
||||
std::string str;
|
||||
Json::Dump(jnext, &str);
|
||||
worker.Send(StringView{str});
|
||||
});
|
||||
}
|
||||
|
||||
for (auto& t : bootstrap_threads) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
for (auto const& w : workers) {
|
||||
worker_error_handles_.emplace_back(w.Host(), w.ErrorPort());
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
[[nodiscard]] std::future<Result> RabitTracker::Run() {
|
||||
// a state machine to keep track of consistency.
|
||||
struct State {
|
||||
std::int32_t const n_workers;
|
||||
|
||||
std::int32_t n_shutdown{0};
|
||||
bool during_restart{false};
|
||||
std::vector<WorkerProxy> pending;
|
||||
|
||||
explicit State(std::int32_t world) : n_workers{world} {}
|
||||
State(State const& that) = delete;
|
||||
State& operator=(State&& that) = delete;
|
||||
|
||||
void Start(WorkerProxy&& worker) {
|
||||
CHECK_LT(pending.size(), n_workers);
|
||||
CHECK_LE(n_shutdown, n_workers);
|
||||
|
||||
pending.emplace_back(std::forward<WorkerProxy>(worker));
|
||||
|
||||
CHECK_LE(pending.size(), n_workers);
|
||||
}
|
||||
void Shutdown() {
|
||||
CHECK_GE(n_shutdown, 0);
|
||||
CHECK_LT(n_shutdown, n_workers);
|
||||
|
||||
++n_shutdown;
|
||||
|
||||
CHECK_LE(n_shutdown, n_workers);
|
||||
}
|
||||
void Error() {
|
||||
CHECK_LE(pending.size(), n_workers);
|
||||
CHECK_LE(n_shutdown, n_workers);
|
||||
|
||||
during_restart = true;
|
||||
}
|
||||
[[nodiscard]] bool Ready() const {
|
||||
CHECK_LE(pending.size(), n_workers);
|
||||
return static_cast<std::int32_t>(pending.size()) == n_workers;
|
||||
}
|
||||
void Bootstrap() {
|
||||
CHECK_EQ(pending.size(), n_workers);
|
||||
CHECK_LE(n_shutdown, n_workers);
|
||||
|
||||
// A reset.
|
||||
n_shutdown = 0;
|
||||
during_restart = false;
|
||||
pending.clear();
|
||||
}
|
||||
[[nodiscard]] bool ShouldContinue() const {
|
||||
CHECK_LE(pending.size(), n_workers);
|
||||
CHECK_LE(n_shutdown, n_workers);
|
||||
// - Without error, we should shutdown after all workers are offline.
|
||||
// - With error, all workers are offline, and we have during_restart as true.
|
||||
return n_shutdown != n_workers || during_restart;
|
||||
}
|
||||
};
|
||||
|
||||
return std::async(std::launch::async, [this] {
|
||||
State state{this->n_workers_};
|
||||
|
||||
while (state.ShouldContinue()) {
|
||||
TCPSocket sock;
|
||||
SockAddrV4 addr;
|
||||
auto rc = listener_.Accept(&sock, &addr);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to accept connection.", std::move(rc));
|
||||
}
|
||||
|
||||
auto worker = WorkerProxy{n_workers_, std::move(sock), std::move(addr)};
|
||||
if (!worker.Status().OK()) {
|
||||
return Fail("Failed to initialize worker proxy.", std::move(worker.Status()));
|
||||
}
|
||||
switch (worker.Command()) {
|
||||
case proto::CMD::kStart: {
|
||||
state.Start(std::move(worker));
|
||||
if (state.Ready()) {
|
||||
rc = this->Bootstrap(&state.pending);
|
||||
state.Bootstrap();
|
||||
}
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
case proto::CMD::kShutdown: {
|
||||
state.Shutdown();
|
||||
continue;
|
||||
}
|
||||
case proto::CMD::kError: {
|
||||
if (state.during_restart) {
|
||||
continue;
|
||||
}
|
||||
state.Error();
|
||||
auto msg = worker.Msg();
|
||||
auto code = worker.Code();
|
||||
LOG(WARNING) << "Recieved error from [" << worker.Host() << ":" << worker.Rank()
|
||||
<< "]: " << msg << " code:" << code;
|
||||
auto host = worker.Host();
|
||||
// We signal all workers for the error, if they haven't aborted already.
|
||||
for (auto& w : worker_error_handles_) {
|
||||
if (w.first == host) {
|
||||
continue;
|
||||
}
|
||||
TCPSocket out;
|
||||
// retry is set to 1, just let the worker timeout or error. Otherwise the
|
||||
// tracker and the worker might be waiting for each other.
|
||||
auto rc = Connect(w.first, w.second, 1, timeout_, &out);
|
||||
// send signal to stop the worker.
|
||||
proto::ShutdownCMD shutdown;
|
||||
rc = shutdown.Send(&out);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to inform workers to stop.");
|
||||
}
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
case proto::CMD::kPrint: {
|
||||
LOG(CONSOLE) << worker.Msg();
|
||||
continue;
|
||||
}
|
||||
case proto::CMD::kInvalid:
|
||||
default: {
|
||||
return Fail("Invalid command received.");
|
||||
}
|
||||
}
|
||||
}
|
||||
return Success();
|
||||
});
|
||||
}
|
||||
|
||||
[[nodiscard]] Result GetHostAddress(std::string* out) {
|
||||
auto rc = GetHostName(out);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
auto host = gethostbyname(out->c_str());
|
||||
|
||||
// get ip address from host
|
||||
std::string ip;
|
||||
rc = INetNToP(host, &ip);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
if (!(ip.size() >= 4 && ip.substr(0, 4) == "127.")) {
|
||||
// return if this is a public IP address.
|
||||
// not entirely accurate, we have other reserved IPs
|
||||
*out = ip;
|
||||
return Success();
|
||||
}
|
||||
|
||||
// Create an UDP socket to prob the public IP address, it's fine even if it's
|
||||
// unreachable.
|
||||
auto sock = socket(AF_INET, SOCK_DGRAM, 0);
|
||||
if (sock == -1) {
|
||||
return Fail("Failed to create socket.");
|
||||
}
|
||||
|
||||
auto paddr = MakeSockAddress(StringView{"10.255.255.255"}, 1);
|
||||
sockaddr const* addr_handle = reinterpret_cast<const sockaddr*>(&paddr.V4().Handle());
|
||||
socklen_t addr_len{sizeof(paddr.V4().Handle())};
|
||||
auto err = connect(sock, addr_handle, addr_len);
|
||||
if (err != 0) {
|
||||
return system::FailWithCode("Failed to find IP address.");
|
||||
}
|
||||
|
||||
// get the IP address from socket desrciptor
|
||||
struct sockaddr_in addr;
|
||||
socklen_t len = sizeof(addr);
|
||||
if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr), &len) == -1) {
|
||||
return Fail("Failed to get sock name.");
|
||||
}
|
||||
ip = inet_ntoa(addr.sin_addr);
|
||||
|
||||
err = system::CloseSocket(sock);
|
||||
if (err != 0) {
|
||||
return system::FailWithCode("Failed to close socket.");
|
||||
}
|
||||
|
||||
*out = ip;
|
||||
return Success();
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
141
src/collective/tracker.h
Normal file
141
src/collective/tracker.h
Normal file
@@ -0,0 +1,141 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int32_t
|
||||
#include <future> // for future
|
||||
#include <string> // for string
|
||||
#include <utility> // for pair
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "protocol.h"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
*
|
||||
* @brief Implementation of RABIT tracker.
|
||||
*
|
||||
* * What is a tracker
|
||||
*
|
||||
* The implementation of collective follows what RABIT did in the past. It requires a
|
||||
* tracker to coordinate initialization and error recovery of workers. While the
|
||||
* original implementation attempted to attain error resislient inside the collective
|
||||
* module, which turned out be too challenging due to large amount of external
|
||||
* states. The new implementation here differs from RABIT in the way that neither state
|
||||
* recovery nor resislient is handled inside the collective, it merely provides the
|
||||
* mechanism to signal error to other workers through the use of a centralized tracker.
|
||||
*
|
||||
* There are three major functionalities provided the a tracker, namely:
|
||||
* - Initialization. Share the node addresses among all workers.
|
||||
* - Logging.
|
||||
* - Signal error. If an exception is thrown in one (or many) of the workers, it can
|
||||
* signal an error to the tracker and the tracker will notify other workers.
|
||||
*/
|
||||
class Tracker {
|
||||
protected:
|
||||
std::int32_t n_workers_{0};
|
||||
std::int32_t port_{-1};
|
||||
std::chrono::seconds timeout_{0};
|
||||
|
||||
public:
|
||||
explicit Tracker(Json const& config);
|
||||
Tracker(std::int32_t n_worders, std::int32_t port, std::chrono::seconds timeout)
|
||||
: n_workers_{n_worders}, port_{port}, timeout_{timeout} {}
|
||||
|
||||
virtual ~Tracker() noexcept(false){}; // NOLINT
|
||||
[[nodiscard]] virtual std::future<Result> Run() = 0;
|
||||
[[nodiscard]] virtual Json WorkerArgs() const = 0;
|
||||
[[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; }
|
||||
};
|
||||
|
||||
class RabitTracker : public Tracker {
|
||||
// a wrapper for connected worker socket.
|
||||
class WorkerProxy {
|
||||
TCPSocket sock_;
|
||||
proto::PeerInfo info_;
|
||||
std::int32_t eport_{0};
|
||||
std::int32_t world_{-1};
|
||||
std::string task_id_;
|
||||
|
||||
proto::CMD cmd_{proto::CMD::kInvalid};
|
||||
std::string msg_;
|
||||
std::int32_t code_{0};
|
||||
Result rc_;
|
||||
|
||||
public:
|
||||
explicit WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr);
|
||||
WorkerProxy(WorkerProxy const& that) = delete;
|
||||
WorkerProxy(WorkerProxy&& that) = default;
|
||||
WorkerProxy& operator=(WorkerProxy const&) = delete;
|
||||
WorkerProxy& operator=(WorkerProxy&&) = default;
|
||||
|
||||
[[nodiscard]] auto Host() const { return info_.host; }
|
||||
[[nodiscard]] auto TaskID() const { return task_id_; }
|
||||
[[nodiscard]] auto Port() const { return info_.port; }
|
||||
[[nodiscard]] auto Rank() const { return info_.rank; }
|
||||
[[nodiscard]] auto ErrorPort() const { return eport_; }
|
||||
[[nodiscard]] auto Command() const { return cmd_; }
|
||||
[[nodiscard]] auto Msg() const { return msg_; }
|
||||
[[nodiscard]] auto Code() const { return code_; }
|
||||
|
||||
[[nodiscard]] Result const& Status() const { return rc_; }
|
||||
[[nodiscard]] Result& Status() { return rc_; }
|
||||
|
||||
void Send(StringView value) { this->sock_.Send(value); }
|
||||
};
|
||||
// provide an ordering for workers, this helps us get deterministic topology.
|
||||
struct WorkerCmp {
|
||||
[[nodiscard]] bool operator()(WorkerProxy const& lhs, WorkerProxy const& rhs) {
|
||||
auto const& lh = lhs.Host();
|
||||
auto const& rh = rhs.Host();
|
||||
|
||||
if (lh != rh) {
|
||||
return lh < rh;
|
||||
}
|
||||
return lhs.TaskID() < rhs.TaskID();
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
std::string host_;
|
||||
// record for how to reach out to workers if error happens.
|
||||
std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
|
||||
// listening socket for incoming workers.
|
||||
TCPSocket listener_;
|
||||
|
||||
Result Bootstrap(std::vector<WorkerProxy>* p_workers);
|
||||
|
||||
public:
|
||||
explicit RabitTracker(StringView host, std::int32_t n_worders, std::int32_t port,
|
||||
std::chrono::seconds timeout)
|
||||
: Tracker{n_worders, port, timeout}, host_{host.c_str(), host.size()} {
|
||||
listener_ = TCPSocket::Create(SockDomain::kV4);
|
||||
auto rc = listener_.Bind(host, &this->port_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
listener_.Listen();
|
||||
}
|
||||
|
||||
explicit RabitTracker(Json const& config);
|
||||
~RabitTracker() noexcept(false) override = default;
|
||||
|
||||
std::future<Result> Run() override;
|
||||
|
||||
[[nodiscard]] std::int32_t Port() const { return port_; }
|
||||
[[nodiscard]] Json WorkerArgs() const override {
|
||||
Json args{Object{}};
|
||||
args["DMLC_TRACKER_URI"] = String{host_};
|
||||
args["DMLC_TRACKER_PORT"] = this->Port();
|
||||
return args;
|
||||
}
|
||||
};
|
||||
|
||||
// Prob the public IP address of the host, need a better method.
|
||||
//
|
||||
// This is directly translated from the previous Python implementation, we should find a
|
||||
// more riguous approach, can use some expertise in network programming.
|
||||
[[nodiscard]] Result GetHostAddress(std::string* out);
|
||||
} // namespace xgboost::collective
|
||||
Reference in New Issue
Block a user