merge latest changes
This commit is contained in:
@@ -26,18 +26,19 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
}
|
||||
|
||||
for (std::int32_t r = 0; r < world; ++r) {
|
||||
auto send_rank = (rank + world - r + worker_off) % world;
|
||||
auto send_off = send_rank * segment_size;
|
||||
send_off = std::min(send_off, data.size_bytes());
|
||||
auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off));
|
||||
next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
|
||||
|
||||
auto recv_rank = (rank + world - r - 1 + worker_off) % world;
|
||||
auto recv_off = recv_rank * segment_size;
|
||||
recv_off = std::min(recv_off, data.size_bytes());
|
||||
auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off));
|
||||
prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
|
||||
auto rc = prev_ch->Block();
|
||||
auto rc = Success() << [&] {
|
||||
auto send_rank = (rank + world - r + worker_off) % world;
|
||||
auto send_off = send_rank * segment_size;
|
||||
send_off = std::min(send_off, data.size_bytes());
|
||||
auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off));
|
||||
return next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
|
||||
} << [&] {
|
||||
auto recv_rank = (rank + world - r - 1 + worker_off) % world;
|
||||
auto recv_off = recv_rank * segment_size;
|
||||
recv_off = std::min(recv_off, data.size_bytes());
|
||||
auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off));
|
||||
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
|
||||
} << [&] { return prev_ch->Block(); };
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
@@ -78,19 +79,19 @@ namespace detail {
|
||||
auto next_ch = comm.Chan(next);
|
||||
|
||||
for (std::int32_t r = 0; r < world; ++r) {
|
||||
auto send_rank = (rank + world - r) % world;
|
||||
auto send_off = offset[send_rank];
|
||||
auto send_size = sizes[send_rank];
|
||||
auto send_seg = erased_result.subspan(send_off, send_size);
|
||||
next_ch->SendAll(send_seg);
|
||||
|
||||
auto recv_rank = (rank + world - r - 1) % world;
|
||||
auto recv_off = offset[recv_rank];
|
||||
auto recv_size = sizes[recv_rank];
|
||||
auto recv_seg = erased_result.subspan(recv_off, recv_size);
|
||||
prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
|
||||
|
||||
auto rc = prev_ch->Block();
|
||||
auto rc = Success() << [&] {
|
||||
auto send_rank = (rank + world - r) % world;
|
||||
auto send_off = offset[send_rank];
|
||||
auto send_size = sizes[send_rank];
|
||||
auto send_seg = erased_result.subspan(send_off, send_size);
|
||||
return next_ch->SendAll(send_seg);
|
||||
} << [&] {
|
||||
auto recv_rank = (rank + world - r - 1) % world;
|
||||
auto recv_off = offset[recv_rank];
|
||||
auto recv_size = sizes[recv_rank];
|
||||
auto recv_seg = erased_result.subspan(recv_off, recv_size);
|
||||
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
|
||||
} << [&] { return prev_ch->Block(); };
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <algorithm> // for min
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t, int8_t
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../data/array_interface.h" // for Type, DispatchDType
|
||||
@@ -36,7 +37,10 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg);
|
||||
auto send_seg = data.subspan(send_off, seg_nbytes);
|
||||
|
||||
next_ch->SendAll(send_seg);
|
||||
auto rc = next_ch->SendAll(send_seg);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
// receive from ring prev
|
||||
auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg;
|
||||
@@ -46,8 +50,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
auto recv_seg = data.subspan(recv_off, seg_nbytes);
|
||||
auto seg = s_buf.subspan(0, recv_seg.size());
|
||||
|
||||
prev_ch->RecvAll(seg);
|
||||
auto rc = prev_ch->Block();
|
||||
rc = std::move(rc) << [&] { return prev_ch->RecvAll(seg); } << [&] { return comm.Block(); };
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
@@ -62,6 +65,9 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
|
||||
Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const& op,
|
||||
ArrayInterfaceHandler::Type type) {
|
||||
if (comm.World() == 1) {
|
||||
return Success();
|
||||
}
|
||||
return DispatchDType(type, [&](auto t) {
|
||||
using T = decltype(t);
|
||||
// Divide the data into segments according to the number of workers.
|
||||
@@ -80,11 +86,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
|
||||
auto prev_ch = comm.Chan(prev);
|
||||
auto next_ch = comm.Chan(next);
|
||||
|
||||
rc = RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
return comm.Block();
|
||||
return std::move(rc) << [&] {
|
||||
return RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
|
||||
} << [&] { return comm.Block(); };
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective::cpu_impl
|
||||
|
||||
@@ -62,8 +62,8 @@ Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t
|
||||
|
||||
if (shifted_rank != 0) { // not root
|
||||
auto parent = ShiftRight(ShiftedParentRank(shifted_rank, depth), world, root);
|
||||
comm.Chan(parent)->RecvAll(data);
|
||||
auto rc = comm.Chan(parent)->Block();
|
||||
auto rc = Success() << [&] { return comm.Chan(parent)->RecvAll(data); }
|
||||
<< [&] { return comm.Chan(parent)->Block(); };
|
||||
if (!rc.OK()) {
|
||||
return Fail("broadcast failed.", std::move(rc));
|
||||
}
|
||||
@@ -75,7 +75,10 @@ Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t
|
||||
auto sft_peer = shifted_rank + (1 << i);
|
||||
auto peer = ShiftRight(sft_peer, world, root);
|
||||
CHECK_NE(peer, root);
|
||||
comm.Chan(peer)->SendAll(data);
|
||||
auto rc = comm.Chan(peer)->SendAll(data);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -23,25 +23,6 @@ Coll* Coll::MakeCUDAVar() { return new NCCLColl{}; }
|
||||
|
||||
NCCLColl::~NCCLColl() = default;
|
||||
namespace {
|
||||
Result GetNCCLResult(ncclResult_t code) {
|
||||
if (code == ncclSuccess) {
|
||||
return Success();
|
||||
}
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "NCCL failure: " << ncclGetErrorString(code) << ".";
|
||||
if (code == ncclUnhandledCudaError) {
|
||||
// nccl usually preserves the last error so we can get more details.
|
||||
auto err = cudaPeekAtLastError();
|
||||
ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n";
|
||||
} else if (code == ncclSystemError) {
|
||||
ss << " This might be caused by a network configuration issue. Please consider specifying "
|
||||
"the network interface for NCCL via environment variables listed in its reference: "
|
||||
"`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
|
||||
}
|
||||
return Fail(ss.str());
|
||||
}
|
||||
|
||||
auto GetNCCLType(ArrayInterfaceHandler::Type type) {
|
||||
auto fatal = [] {
|
||||
LOG(FATAL) << "Invalid type for NCCL operation.";
|
||||
@@ -98,11 +79,12 @@ void RunBitwiseAllreduce(dh::CUDAStreamView stream, common::Span<std::int8_t> ou
|
||||
common::Span<std::int8_t> data, Op op) {
|
||||
dh::device_vector<std::int8_t> buffer(data.size() * pcomm->World());
|
||||
auto* device_buffer = buffer.data().get();
|
||||
auto stub = pcomm->Stub();
|
||||
|
||||
// First gather data from all the workers.
|
||||
CHECK(handle);
|
||||
auto rc = GetNCCLResult(
|
||||
ncclAllGather(data.data(), device_buffer, data.size(), ncclInt8, handle, pcomm->Stream()));
|
||||
auto rc =
|
||||
stub->Allgather(data.data(), device_buffer, data.size(), ncclInt8, handle, pcomm->Stream());
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
@@ -153,6 +135,8 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
|
||||
}
|
||||
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
|
||||
CHECK(nccl);
|
||||
auto stub = nccl->Stub();
|
||||
|
||||
return Success() << [&] {
|
||||
if (IsBitwiseOp(op)) {
|
||||
return BitwiseAllReduce(nccl, nccl->Handle(), data, op);
|
||||
@@ -160,9 +144,8 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
|
||||
return DispatchDType(type, [=](auto t) {
|
||||
using T = decltype(t);
|
||||
auto rdata = common::RestoreType<T>(data);
|
||||
auto rc = ncclAllReduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
|
||||
GetNCCLRedOp(op), nccl->Handle(), nccl->Stream());
|
||||
return GetNCCLResult(rc);
|
||||
return stub->Allreduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
|
||||
GetNCCLRedOp(op), nccl->Handle(), nccl->Stream());
|
||||
});
|
||||
}
|
||||
} << [&] { return nccl->Block(); };
|
||||
@@ -175,9 +158,11 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
|
||||
}
|
||||
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
|
||||
CHECK(nccl);
|
||||
auto stub = nccl->Stub();
|
||||
|
||||
return Success() << [&] {
|
||||
return GetNCCLResult(ncclBroadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root,
|
||||
nccl->Handle(), nccl->Stream()));
|
||||
return stub->Broadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root,
|
||||
nccl->Handle(), nccl->Stream());
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
|
||||
@@ -188,10 +173,12 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
|
||||
}
|
||||
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
|
||||
CHECK(nccl);
|
||||
auto stub = nccl->Stub();
|
||||
|
||||
auto send = data.subspan(comm.Rank() * size, size);
|
||||
return Success() << [&] {
|
||||
return GetNCCLResult(
|
||||
ncclAllGather(send.data(), data.data(), size, ncclInt8, nccl->Handle(), nccl->Stream()));
|
||||
return stub->Allgather(send.data(), data.data(), size, ncclInt8, nccl->Handle(),
|
||||
nccl->Stream());
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
|
||||
@@ -203,19 +190,20 @@ namespace cuda_impl {
|
||||
*/
|
||||
Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes, common::Span<std::int8_t> recv) {
|
||||
return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] {
|
||||
auto stub = comm->Stub();
|
||||
return Success() << [&stub] { return stub->GroupStart(); } << [&] {
|
||||
std::size_t offset = 0;
|
||||
for (std::int32_t r = 0; r < comm->World(); ++r) {
|
||||
auto as_bytes = sizes[r];
|
||||
auto rc = ncclBroadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
|
||||
ncclInt8, r, comm->Handle(), dh::DefaultStream());
|
||||
if (rc != ncclSuccess) {
|
||||
return GetNCCLResult(rc);
|
||||
auto rc = stub->Broadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
|
||||
ncclInt8, r, comm->Handle(), dh::DefaultStream());
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
offset += as_bytes;
|
||||
}
|
||||
return Success();
|
||||
} << [] { return GetNCCLResult(ncclGroupEnd()); };
|
||||
} << [&] { return stub->GroupEnd(); };
|
||||
}
|
||||
} // namespace cuda_impl
|
||||
|
||||
@@ -228,10 +216,11 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const>
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
auto stub = nccl->Stub();
|
||||
|
||||
switch (algo) {
|
||||
case AllgatherVAlgo::kRing: {
|
||||
return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] {
|
||||
return Success() << [&] { return stub->GroupStart(); } << [&] {
|
||||
// get worker offset
|
||||
detail::AllgatherVOffset(sizes, recv_segments);
|
||||
// copy data
|
||||
@@ -241,8 +230,8 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const>
|
||||
cudaMemcpyDeviceToDevice, nccl->Stream()));
|
||||
}
|
||||
return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
|
||||
} << [] {
|
||||
return GetNCCLResult(ncclGroupEnd());
|
||||
} << [&] {
|
||||
return stub->GroupEnd();
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
case AllgatherVAlgo::kBcast: {
|
||||
|
||||
@@ -8,7 +8,8 @@
|
||||
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||
#include "coll.h" // for Coll
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/span.h" // for Span
|
||||
#include "nccl_stub.h"
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
class NCCLColl : public Coll {
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include <algorithm> // for copy
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdlib> // for exit
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string> // for string
|
||||
#include <utility> // for move, forward
|
||||
@@ -29,19 +30,28 @@ Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds time
|
||||
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string const& task_id, TCPSocket* out, std::int32_t rank,
|
||||
std::int32_t world) {
|
||||
// get information from tracker
|
||||
// Get information from the tracker
|
||||
CHECK(!info.host.empty());
|
||||
auto rc = Connect(info.host, info.port, retry, timeout, out);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to connect to the tracker.", std::move(rc));
|
||||
}
|
||||
|
||||
TCPSocket& tracker = *out;
|
||||
return std::move(rc)
|
||||
<< [&] { return tracker.NonBlocking(false); }
|
||||
<< [&] { return tracker.RecvTimeout(timeout); }
|
||||
<< [&] { return proto::Magic{}.Verify(&tracker); }
|
||||
<< [&] { return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id); };
|
||||
return Success() << [&] {
|
||||
auto rc = Connect(info.host, info.port, retry, timeout, out);
|
||||
if (rc.OK()) {
|
||||
return rc;
|
||||
} else {
|
||||
return Fail("Failed to connect to the tracker.", std::move(rc));
|
||||
}
|
||||
} << [&] {
|
||||
return tracker.NonBlocking(false);
|
||||
} << [&] {
|
||||
return tracker.RecvTimeout(timeout);
|
||||
} << [&] {
|
||||
return proto::Magic{}.Verify(&tracker);
|
||||
} << [&] {
|
||||
return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id);
|
||||
} << [&] {
|
||||
LOG(INFO) << "Task " << task_id << " connected to the tracker";
|
||||
return Success();
|
||||
};
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Comm::ConnectTracker(TCPSocket* out) const {
|
||||
@@ -49,14 +59,6 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
|
||||
this->Rank(), this->World());
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_NCCL) && !defined(XGBOOST_USE_RCCL)
|
||||
Comm* Comm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
||||
common::AssertGPUSupport();
|
||||
common::AssertNCCLSupport();
|
||||
return nullptr;
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_NCCL) && !defined(XGBOOST_USE_RCCL)
|
||||
|
||||
[[nodiscard]] Result ConnectWorkers(Comm const& comm, TCPSocket* listener, std::int32_t lport,
|
||||
proto::PeerInfo ninfo, std::chrono::seconds timeout,
|
||||
std::int32_t retry,
|
||||
@@ -181,12 +183,21 @@ Comm* Comm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
||||
}
|
||||
|
||||
RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t retry, std::string task_id)
|
||||
: Comm{std::move(host), port, timeout, retry, std::move(task_id)} {
|
||||
std::int32_t retry, std::string task_id, StringView nccl_path)
|
||||
: HostComm{std::move(host), port, timeout, retry, std::move(task_id)},
|
||||
nccl_path_{std::move(nccl_path)} {
|
||||
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_NCCL) && !defined(XGBOOST_USE_RCCL)
|
||||
Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
||||
common::AssertGPUSupport();
|
||||
common::AssertNCCLSupport();
|
||||
return nullptr;
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_NCCL) && !defined(XGBOOST_USE_RCCL)
|
||||
|
||||
[[nodiscard]] Result RabitComm::Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string task_id) {
|
||||
TCPSocket tracker;
|
||||
@@ -209,24 +220,18 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
|
||||
std::shared_ptr<TCPSocket> error_sock{TCPSocket::CreatePtr(domain)};
|
||||
auto eport = error_sock->BindHost();
|
||||
error_sock->Listen();
|
||||
error_worker_ = std::thread{[this, error_sock = std::move(error_sock)] {
|
||||
error_worker_ = std::thread{[error_sock = std::move(error_sock)] {
|
||||
auto conn = error_sock->Accept();
|
||||
// On Windows accept returns an invalid socket after network is shutdown.
|
||||
// On Windows, accept returns a closed socket after finalize.
|
||||
if (conn.IsClosed()) {
|
||||
return;
|
||||
}
|
||||
LOG(WARNING) << "Another worker is running into error.";
|
||||
std::string scmd;
|
||||
conn.Recv(&scmd);
|
||||
auto jcmd = Json::Load(scmd);
|
||||
auto rc = this->Shutdown();
|
||||
if (!rc.OK()) {
|
||||
LOG(WARNING) << "Fail to shutdown worker:" << rc.Report();
|
||||
}
|
||||
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
|
||||
exit(-1);
|
||||
// exit is nicer than abort as the former performs cleanups.
|
||||
std::exit(-1);
|
||||
#else
|
||||
LOG(FATAL) << rc.Report();
|
||||
LOG(FATAL) << "abort";
|
||||
#endif
|
||||
}};
|
||||
error_worker_.detach();
|
||||
@@ -259,8 +264,8 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
|
||||
CHECK(this->channels_.empty());
|
||||
for (auto& w : workers) {
|
||||
if (w) {
|
||||
w->SetNoDelay();
|
||||
rc = w->NonBlocking(true);
|
||||
rc = std::move(rc) << [&] { return w->SetNoDelay(); } << [&] { return w->NonBlocking(true); }
|
||||
<< [&] { return w->SetKeepAlive(); };
|
||||
}
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
|
||||
@@ -10,21 +10,24 @@
|
||||
#include <sstream> // for stringstream
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/cuda_context.cuh" // for CUDAContext
|
||||
#include "../common/device_helpers.cuh" // for DefaultStream
|
||||
#include "../common/type.h" // for EraseType
|
||||
#include "broadcast.h" // for Broadcast
|
||||
#include "comm.cuh" // for NCCLComm
|
||||
#include "comm.h" // for Comm
|
||||
#include "nccl_stub.h" // for NcclStub
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
Result GetUniqueId(Comm const& comm, std::shared_ptr<Coll> coll, ncclUniqueId* pid) {
|
||||
Result GetUniqueId(Comm const& comm, std::shared_ptr<NcclStub> stub, std::shared_ptr<Coll> coll,
|
||||
ncclUniqueId* pid) {
|
||||
static const int kRootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (comm.Rank() == kRootRank) {
|
||||
dh::safe_nccl(ncclGetUniqueId(&id));
|
||||
auto rc = stub->GetUniqueId(&id);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
auto rc = coll->Broadcast(
|
||||
comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)}, kRootRank);
|
||||
@@ -63,14 +66,15 @@ static std::string PrintUUID(xgboost::common::Span<std::uint64_t, kUuidLength> c
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const {
|
||||
return new NCCLComm{ctx, *this, pimpl};
|
||||
Comm* RabitComm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const {
|
||||
return new NCCLComm{ctx, *this, pimpl, StringView{this->nccl_path_}};
|
||||
}
|
||||
|
||||
NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl)
|
||||
NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl,
|
||||
StringView nccl_path)
|
||||
: Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(),
|
||||
root.TaskID()},
|
||||
stream_{dh::DefaultStream()} {
|
||||
stream_{ctx->CUDACtx()->Stream()} {
|
||||
this->world_ = root.World();
|
||||
this->rank_ = root.Rank();
|
||||
this->domain_ = root.Domain();
|
||||
@@ -79,6 +83,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
|
||||
stub_ = std::make_shared<NcclStub>(nccl_path);
|
||||
|
||||
std::vector<std::uint64_t> uuids(root.World() * kUuidLength, 0);
|
||||
auto s_uuid = xgboost::common::Span<std::uint64_t>{uuids.data(), uuids.size()};
|
||||
@@ -104,19 +109,22 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
|
||||
<< "Multiple processes within communication group running on same CUDA "
|
||||
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
||||
|
||||
rc = GetUniqueId(root, pimpl, &nccl_unique_id_);
|
||||
rc = std::move(rc) << [&] { return GetUniqueId(root, this->stub_, pimpl, &nccl_unique_id_); } <<
|
||||
[&] {
|
||||
return this->stub_->CommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank());
|
||||
};
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank()));
|
||||
|
||||
for (std::int32_t r = 0; r < root.World(); ++r) {
|
||||
this->channels_.emplace_back(
|
||||
std::make_shared<NCCLChannel>(root, r, nccl_comm_, dh::DefaultStream()));
|
||||
std::make_shared<NCCLChannel>(root, r, nccl_comm_, stub_, dh::DefaultStream()));
|
||||
}
|
||||
}
|
||||
|
||||
NCCLComm::~NCCLComm() {
|
||||
if (nccl_comm_) {
|
||||
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
|
||||
auto rc = stub_->CommDestroy(nccl_comm_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -9,9 +9,13 @@
|
||||
#include "../common/cuda_to_hip.h"
|
||||
#include "rccl.h"
|
||||
#endif // XGBOOST_USE_NCCL
|
||||
|
||||
#include <utility> // for move
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "coll.h"
|
||||
#include "comm.h"
|
||||
#include "nccl_stub.h" // for NcclStub
|
||||
#include "xgboost/context.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
@@ -24,15 +28,20 @@ inline Result GetCUDAResult(cudaError rc) {
|
||||
return Fail(msg);
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
|
||||
class NCCLComm : public Comm {
|
||||
ncclComm_t nccl_comm_{nullptr};
|
||||
std::shared_ptr<NcclStub> stub_;
|
||||
ncclUniqueId nccl_unique_id_{};
|
||||
dh::CUDAStreamView stream_;
|
||||
std::string nccl_path_;
|
||||
|
||||
public:
|
||||
[[nodiscard]] ncclComm_t Handle() const { return nccl_comm_; }
|
||||
auto Stub() const { return stub_; }
|
||||
|
||||
explicit NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl);
|
||||
explicit NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl,
|
||||
StringView nccl_path);
|
||||
[[nodiscard]] Result LogTracker(std::string) const override {
|
||||
LOG(FATAL) << "Device comm is used for logging.";
|
||||
return Fail("Undefined.");
|
||||
@@ -49,22 +58,29 @@ class NCCLComm : public Comm {
|
||||
class NCCLChannel : public Channel {
|
||||
std::int32_t rank_{-1};
|
||||
ncclComm_t nccl_comm_{};
|
||||
std::shared_ptr<NcclStub> stub_;
|
||||
dh::CUDAStreamView stream_;
|
||||
|
||||
public:
|
||||
explicit NCCLChannel(Comm const& comm, std::int32_t rank, ncclComm_t nccl_comm,
|
||||
dh::CUDAStreamView stream)
|
||||
: rank_{rank}, nccl_comm_{nccl_comm}, Channel{comm, nullptr}, stream_{stream} {}
|
||||
std::shared_ptr<NcclStub> stub, dh::CUDAStreamView stream)
|
||||
: rank_{rank},
|
||||
nccl_comm_{nccl_comm},
|
||||
stub_{std::move(stub)},
|
||||
Channel{comm, nullptr},
|
||||
stream_{stream} {}
|
||||
|
||||
void SendAll(std::int8_t const* ptr, std::size_t n) override {
|
||||
dh::safe_nccl(ncclSend(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
|
||||
[[nodiscard]] Result SendAll(std::int8_t const* ptr, std::size_t n) override {
|
||||
return stub_->Send(ptr, n, ncclInt8, rank_, nccl_comm_, stream_);
|
||||
}
|
||||
void RecvAll(std::int8_t* ptr, std::size_t n) override {
|
||||
dh::safe_nccl(ncclRecv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
|
||||
[[nodiscard]] Result RecvAll(std::int8_t* ptr, std::size_t n) override {
|
||||
return stub_->Recv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_);
|
||||
}
|
||||
[[nodiscard]] Result Block() override {
|
||||
auto rc = stream_.Sync(false);
|
||||
return GetCUDAResult(rc);
|
||||
}
|
||||
};
|
||||
|
||||
#endif // defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -34,6 +34,8 @@ inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) {
|
||||
return nrank;
|
||||
}
|
||||
|
||||
inline StringView DefaultNcclName() { return "libnccl.so.2"; }
|
||||
|
||||
class Channel;
|
||||
class Coll;
|
||||
|
||||
@@ -86,11 +88,21 @@ class Comm : public std::enable_shared_from_this<Comm> {
|
||||
[[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
|
||||
|
||||
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
|
||||
|
||||
virtual Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const;
|
||||
};
|
||||
|
||||
class RabitComm : public Comm {
|
||||
/**
|
||||
* @brief Base class for CPU-based communicator.
*
* Re-uses the constructors of Comm and declares MakeCUDAVar as pure virtual, so a
* concrete host communicator has to provide its own implementation.
|
||||
*/
|
||||
class HostComm : public Comm {
|
||||
public:
|
||||
// Inherit every constructor of Comm unchanged.
using Comm::Comm;
|
||||
// Create the CUDA counterpart of this communicator. The result is a raw pointer.
// NOTE(review): ownership presumably passes to the caller -- confirm against call sites.
[[nodiscard]] virtual Comm* MakeCUDAVar(Context const* ctx,
|
||||
std::shared_ptr<Coll> pimpl) const = 0;
|
||||
};
|
||||
|
||||
class RabitComm : public HostComm {
|
||||
std::string nccl_path_ = std::string{DefaultNcclName()};
|
||||
|
||||
[[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string task_id);
|
||||
[[nodiscard]] Result Shutdown();
|
||||
@@ -100,13 +112,15 @@ class RabitComm : public Comm {
|
||||
RabitComm() = default;
|
||||
// ctor for testing where environment is known.
|
||||
RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t retry, std::string task_id);
|
||||
std::int32_t retry, std::string task_id, StringView nccl_path);
|
||||
~RabitComm() noexcept(false) override;
|
||||
|
||||
[[nodiscard]] bool IsFederated() const override { return false; }
|
||||
[[nodiscard]] Result LogTracker(std::string msg) const override;
|
||||
|
||||
[[nodiscard]] Result SignalError(Result const&) override;
|
||||
|
||||
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -121,21 +135,25 @@ class Channel {
|
||||
explicit Channel(Comm const& comm, std::shared_ptr<TCPSocket> sock)
|
||||
: sock_{std::move(sock)}, comm_{comm} {}
|
||||
|
||||
virtual void SendAll(std::int8_t const* ptr, std::size_t n) {
|
||||
[[nodiscard]] virtual Result SendAll(std::int8_t const* ptr, std::size_t n) {
|
||||
Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast<std::int8_t*>(ptr), n, sock_.get(), 0};
|
||||
CHECK(sock_.get());
|
||||
comm_.Submit(std::move(op));
|
||||
return Success();
|
||||
}
|
||||
void SendAll(common::Span<std::int8_t const> data) {
|
||||
this->SendAll(data.data(), data.size_bytes());
|
||||
[[nodiscard]] Result SendAll(common::Span<std::int8_t const> data) {
|
||||
return this->SendAll(data.data(), data.size_bytes());
|
||||
}
|
||||
|
||||
virtual void RecvAll(std::int8_t* ptr, std::size_t n) {
|
||||
[[nodiscard]] virtual Result RecvAll(std::int8_t* ptr, std::size_t n) {
|
||||
Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0};
|
||||
CHECK(sock_.get());
|
||||
comm_.Submit(std::move(op));
|
||||
return Success();
|
||||
}
|
||||
[[nodiscard]] Result RecvAll(common::Span<std::int8_t> data) {
|
||||
return this->RecvAll(data.data(), data.size_bytes());
|
||||
}
|
||||
void RecvAll(common::Span<std::int8_t> data) { this->RecvAll(data.data(), data.size_bytes()); }
|
||||
|
||||
[[nodiscard]] auto Socket() const { return sock_; }
|
||||
[[nodiscard]] virtual Result Block() { return comm_.Block(); }
|
||||
|
||||
122
src/collective/comm_group.cc
Normal file
122
src/collective/comm_group.cc
Normal file
@@ -0,0 +1,122 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include "comm_group.h"
|
||||
|
||||
#include <algorithm> // for transform
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int32_t
|
||||
#include <memory> // for shared_ptr, unique_ptr
|
||||
#include <string> // for string
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/json_utils.h" // for OptionalArg
|
||||
#include "coll.h" // for Coll
|
||||
#include "comm.h" // for Comm
|
||||
#include "tracker.h" // for GetHostAddress
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/context.h" // for DeviceOrd
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_coll.h"
|
||||
#include "../../plugin/federated/federated_comm.h"
|
||||
#endif
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
 * @brief Pick the collective implementation for the given device.
 *
 * Non-CUDA devices get `backend_`. For CUDA devices the GPU variant is created from
 * `backend_->MakeCUDAVar()` on first request and cached in `gpu_coll_`.
 */
[[nodiscard]] std::shared_ptr<Coll> CommGroup::Backend(DeviceOrd device) const {
  // Host path first: nothing to create.
  if (!device.IsCUDA()) {
    return backend_;
  }
  // Lazily build the device variant the first time it is asked for.
  if (gpu_coll_ == nullptr) {
    gpu_coll_.reset(backend_->MakeCUDAVar());
  }
  return gpu_coll_;
}
|
||||
|
||||
/**
 * @brief Pick the communicator for the given device.
 *
 * Non-CUDA devices get `*comm_`. For CUDA devices the context must be a CUDA context;
 * the device communicator is (re)built through `comm_->MakeCUDAVar` when it does not
 * exist yet or when its world size differs from that of `comm_`.
 */
[[nodiscard]] Comm const& CommGroup::Ctx(Context const* ctx, DeviceOrd device) const {
  if (!device.IsCUDA()) {
    return *comm_;
  }

  CHECK(ctx->IsCUDA());
  // Short-circuit keeps the World() call from touching a null gpu_comm_.
  bool const needs_rebuild = !gpu_comm_ || gpu_comm_->World() != comm_->World();
  if (needs_rebuild) {
    gpu_comm_.reset(comm_->MakeCUDAVar(ctx, backend_));
  }
  return *gpu_comm_;
}
|
||||
|
||||
CommGroup::CommGroup()
|
||||
: comm_{std::shared_ptr<RabitComm>(new RabitComm{})}, // NOLINT
|
||||
backend_{std::shared_ptr<Coll>(new Coll{})} {} // NOLINT
|
||||
|
||||
/**
 * @brief Create a communicator group from a JSON configuration.
 *
 * A null `config` yields a default-constructed group. Otherwise the
 * `dmlc_communicator` entry (default "rabit") selects the implementation. The remaining
 * parameters are looked up by their upper-case name first and by their lower-case name
 * second.
 *
 * @param config Null, or a JSON object holding the `dmlc_*` parameters.
 *
 * @return A newly allocated group (callers in this file hand it to a unique_ptr).
 *         Unknown communicator types, and "federated" in a build without
 *         XGBOOST_USE_FEDERATED, are reported through LOG(FATAL); the trailing
 *         `return nullptr` is only reached if LOG(FATAL) returns.
 */
[[nodiscard]] CommGroup* CommGroup::Create(Json config) {
  if (IsA<Null>(config)) {
    return new CommGroup;
  }

  std::string type = OptionalArg<String>(config, "dmlc_communicator", std::string{"rabit"});
  // Try both lower and upper case for compatibility
  auto get_param = [&](std::string name, auto dft, auto t) {
    // The <cctype> functions require a value representable as unsigned char; passing a
    // negative plain char is undefined behaviour, hence the casts.
    std::string upper;
    std::transform(name.cbegin(), name.cend(), std::back_inserter(upper), [](char c) {
      return static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
    });
    std::transform(name.cbegin(), name.cend(), name.begin(), [](char c) {
      return static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
    });

    auto const& obj = get<Object const>(config);
    auto it = obj.find(upper);
    if (it != obj.cend()) {
      return OptionalArg<decltype(t)>(config, upper, dft);
    } else {
      return OptionalArg<decltype(t)>(config, name, dft);
    }
  };
  // Common args
  auto retry = get_param("dmlc_retry", static_cast<Integer::Int>(DefaultRetry()), Integer{});
  auto timeout =
      get_param("dmlc_timeout_sec", static_cast<Integer::Int>(DefaultTimeoutSec()), Integer{});
  auto task_id = get_param("dmlc_task_id", std::string{}, String{});

  if (type == "rabit") {
    auto host = get_param("dmlc_tracker_uri", std::string{}, String{});
    auto port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
    auto nccl = get_param("dmlc_nccl_path", std::string{DefaultNcclName()}, String{});
    auto ptr =
        new CommGroup{std::shared_ptr<RabitComm>{new RabitComm{  // NOLINT
                          host, static_cast<std::int32_t>(port), std::chrono::seconds{timeout},
                          static_cast<std::int32_t>(retry), task_id, nccl}},
                      std::shared_ptr<Coll>(new Coll{})};  // NOLINT
    return ptr;
  } else if (type == "federated") {
#if defined(XGBOOST_USE_FEDERATED)
    auto ptr = new CommGroup{
        std::make_shared<FederatedComm>(retry, std::chrono::seconds{timeout}, task_id, config),
        std::make_shared<FederatedColl>()};
    return ptr;
#else
    // Without this branch the request silently fell through to `return nullptr`,
    // leaving the caller with no group and no diagnostic.
    LOG(FATAL) << "XGBoost is not compiled with federated learning support.";
#endif  // defined(XGBOOST_USE_FEDERATED)
  } else {
    LOG(FATAL) << "Invalid communicator type";
  }

  return nullptr;
}
|
||||
|
||||
std::unique_ptr<collective::CommGroup>& GlobalCommGroup() {
|
||||
static thread_local std::unique_ptr<collective::CommGroup> sptr;
|
||||
if (!sptr) {
|
||||
Json config{Null{}};
|
||||
sptr.reset(CommGroup::Create(config));
|
||||
}
|
||||
return sptr;
|
||||
}
|
||||
|
||||
// Replace the calling thread's communicator group with one created from `config`.
void GlobalCommGroupInit(Json config) {
  auto& group = GlobalCommGroup();
  auto* fresh = CommGroup::Create(std::move(config));
  group.reset(fresh);
}
|
||||
|
||||
void GlobalCommGroupFinalize() {
|
||||
auto& sptr = GlobalCommGroup();
|
||||
sptr.reset();
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
55
src/collective/comm_group.h
Normal file
55
src/collective/comm_group.h
Normal file
@@ -0,0 +1,55 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <memory> // for shared_ptr, unique_ptr
|
||||
#include <string> // for string
|
||||
#include <utility> // for move
|
||||
|
||||
#include "coll.h" // for Coll
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/collective/socket.h" // for GetHostName
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
* @brief Communicator group used for double dispatching between communicators and
|
||||
* collective implementations.
|
||||
*/
|
||||
class CommGroup {
|
||||
std::shared_ptr<HostComm> comm_;
|
||||
mutable std::shared_ptr<Comm> gpu_comm_;
|
||||
|
||||
std::shared_ptr<Coll> backend_;
|
||||
mutable std::shared_ptr<Coll> gpu_coll_; // lazy initialization
|
||||
|
||||
CommGroup(std::shared_ptr<Comm> comm, std::shared_ptr<Coll> coll)
|
||||
: comm_{std::dynamic_pointer_cast<HostComm>(comm)}, backend_{std::move(coll)} {
|
||||
CHECK(comm_);
|
||||
}
|
||||
|
||||
public:
|
||||
CommGroup();
|
||||
|
||||
[[nodiscard]] auto World() const { return comm_->World(); }
|
||||
[[nodiscard]] auto Rank() const { return comm_->Rank(); }
|
||||
[[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); }
|
||||
|
||||
[[nodiscard]] static CommGroup* Create(Json config);
|
||||
|
||||
[[nodiscard]] std::shared_ptr<Coll> Backend(DeviceOrd device) const;
|
||||
[[nodiscard]] Comm const& Ctx(Context const* ctx, DeviceOrd device) const;
|
||||
[[nodiscard]] Result SignalError(Result const& res) { return comm_->SignalError(res); }
|
||||
|
||||
[[nodiscard]] Result ProcessorName(std::string* out) const {
|
||||
auto rc = GetHostName(out);
|
||||
return rc;
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<collective::CommGroup>& GlobalCommGroup();
|
||||
|
||||
void GlobalCommGroupInit(Json config);
|
||||
|
||||
void GlobalCommGroupFinalize();
|
||||
} // namespace xgboost::collective
|
||||
@@ -3,6 +3,7 @@
|
||||
*/
|
||||
#include "communicator.h"
|
||||
|
||||
#include "comm.h"
|
||||
#include "in_memory_communicator.h"
|
||||
#include "noop_communicator.h"
|
||||
#include "rabit_communicator.h"
|
||||
@@ -14,8 +15,12 @@
|
||||
namespace xgboost::collective {
|
||||
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
|
||||
thread_local CommunicatorType Communicator::type_{};
|
||||
thread_local std::string Communicator::nccl_path_{};
|
||||
|
||||
void Communicator::Init(Json const& config) {
|
||||
auto nccl = OptionalArg<String>(config, "dmlc_nccl_path", std::string{DefaultNcclName()});
|
||||
nccl_path_ = nccl;
|
||||
|
||||
auto type = GetTypeFromEnv();
|
||||
auto const arg = GetTypeFromConfig(config);
|
||||
if (arg != CommunicatorType::kUnknown) {
|
||||
|
||||
@@ -31,17 +31,17 @@ DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
|
||||
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
|
||||
switch (type_) {
|
||||
case CommunicatorType::kRabit:
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false, nccl_path_));
|
||||
break;
|
||||
case CommunicatorType::kFederated:
|
||||
case CommunicatorType::kInMemory:
|
||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
|
||||
break;
|
||||
case CommunicatorType::kInMemoryNccl:
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true));
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true, nccl_path_));
|
||||
break;
|
||||
default:
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false, nccl_path_));
|
||||
}
|
||||
#else
|
||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
|
||||
|
||||
@@ -234,6 +234,7 @@ class Communicator {
|
||||
|
||||
static thread_local std::unique_ptr<Communicator> communicator_;
|
||||
static thread_local CommunicatorType type_;
|
||||
static thread_local std::string nccl_path_;
|
||||
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
|
||||
static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
|
||||
#endif
|
||||
|
||||
@@ -10,21 +10,26 @@
|
||||
#include "xgboost/logging.h" // for CHECK
|
||||
|
||||
namespace xgboost::collective {
|
||||
Result Loop::EmptyQueue() {
|
||||
Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
timer_.Start(__func__);
|
||||
auto error = [this] {
|
||||
this->stop_ = true;
|
||||
auto error = [this] { timer_.Stop(__func__); };
|
||||
|
||||
if (stop_) {
|
||||
timer_.Stop(__func__);
|
||||
};
|
||||
return Success();
|
||||
}
|
||||
|
||||
while (!queue_.empty() && !stop_) {
|
||||
std::queue<Op> qcopy;
|
||||
auto& qcopy = *p_queue;
|
||||
|
||||
// clear the copied queue
|
||||
while (!qcopy.empty()) {
|
||||
rabit::utils::PollHelper poll;
|
||||
std::size_t n_ops = qcopy.size();
|
||||
|
||||
// watch all ops
|
||||
while (!queue_.empty()) {
|
||||
auto op = queue_.front();
|
||||
queue_.pop();
|
||||
// Iterate through all the ops for poll
|
||||
for (std::size_t i = 0; i < n_ops; ++i) {
|
||||
auto op = qcopy.front();
|
||||
qcopy.pop();
|
||||
|
||||
switch (op.code) {
|
||||
case Op::kRead: {
|
||||
@@ -40,6 +45,7 @@ Result Loop::EmptyQueue() {
|
||||
return Fail("Invalid socket operation.");
|
||||
}
|
||||
}
|
||||
|
||||
qcopy.push(op);
|
||||
}
|
||||
|
||||
@@ -51,10 +57,12 @@ Result Loop::EmptyQueue() {
|
||||
error();
|
||||
return rc;
|
||||
}
|
||||
|
||||
// we wouldn't be here if the queue is empty.
|
||||
CHECK(!qcopy.empty());
|
||||
|
||||
while (!qcopy.empty() && !stop_) {
|
||||
// Iterate through all the ops for performing the operations
|
||||
for (std::size_t i = 0; i < n_ops; ++i) {
|
||||
auto op = qcopy.front();
|
||||
qcopy.pop();
|
||||
|
||||
@@ -81,20 +89,21 @@ Result Loop::EmptyQueue() {
|
||||
}
|
||||
|
||||
if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) {
|
||||
stop_ = true;
|
||||
auto rc = system::FailWithCode("Invalid socket output.");
|
||||
error();
|
||||
return rc;
|
||||
}
|
||||
|
||||
op.off += n_bytes_done;
|
||||
CHECK_LE(op.off, op.n);
|
||||
|
||||
if (op.off != op.n) {
|
||||
// not yet finished, push back to queue for next round.
|
||||
queue_.push(op);
|
||||
qcopy.push(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
timer_.Stop(__func__);
|
||||
return Success();
|
||||
}
|
||||
@@ -107,22 +116,46 @@ void Loop::Process() {
|
||||
if (stop_) {
|
||||
break;
|
||||
}
|
||||
CHECK(!mu_.try_lock());
|
||||
|
||||
this->rc_ = this->EmptyQueue();
|
||||
if (!rc_.OK()) {
|
||||
stop_ = true;
|
||||
auto unlock_notify = [&](bool is_blocking, bool stop) {
|
||||
if (!is_blocking) {
|
||||
std::lock_guard guard{mu_};
|
||||
stop_ = stop;
|
||||
} else {
|
||||
stop_ = stop;
|
||||
lock.unlock();
|
||||
}
|
||||
cv_.notify_one();
|
||||
break;
|
||||
};
|
||||
|
||||
// move the queue
|
||||
std::queue<Op> qcopy;
|
||||
bool is_blocking = false;
|
||||
while (!queue_.empty()) {
|
||||
auto op = queue_.front();
|
||||
queue_.pop();
|
||||
if (op.code == Op::kBlock) {
|
||||
is_blocking = true;
|
||||
} else {
|
||||
qcopy.push(op);
|
||||
}
|
||||
}
|
||||
// unblock the queue
|
||||
if (!is_blocking) {
|
||||
lock.unlock();
|
||||
}
|
||||
// clear the queue
|
||||
auto rc = this->EmptyQueue(&qcopy);
|
||||
// Handle error
|
||||
if (!rc.OK()) {
|
||||
unlock_notify(is_blocking, true);
|
||||
std::lock_guard<std::mutex> guard{rc_lock_};
|
||||
this->rc_ = std::move(rc);
|
||||
return;
|
||||
}
|
||||
|
||||
CHECK(queue_.empty());
|
||||
CHECK(!mu_.try_lock());
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
if (rc_.OK()) {
|
||||
CHECK(queue_.empty());
|
||||
CHECK(qcopy.empty());
|
||||
unlock_notify(is_blocking, false);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,6 +173,24 @@ Result Loop::Stop() {
|
||||
return Success();
|
||||
}
|
||||
|
||||
/**
 * @brief Wait for the queued operations to finish and hand the stored result to the caller.
 *
 * @return The stored result, moved out of the loop. An error that is already recorded is
 *         returned immediately without queueing anything.
 */
[[nodiscard]] Result Loop::Block() {
  // Report an error that is already recorded before submitting more work.
  {
    std::lock_guard<std::mutex> rc_guard{rc_lock_};
    if (!rc_.OK()) {
      return std::move(rc_);
    }
  }

  // Queue a kBlock marker, then sleep until the queue drains or the loop stops.
  this->Submit(Op{Op::kBlock});
  {
    std::unique_lock queue_lock{mu_};
    cv_.wait(queue_lock, [this] { return stop_ || this->queue_.empty(); });
  }

  // Transfer whatever result is stored now.
  std::lock_guard<std::mutex> rc_guard{rc_lock_};
  return std::move(rc_);
}
|
||||
|
||||
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
|
||||
timer_.Init(__func__);
|
||||
worker_ = std::thread{[this] {
|
||||
|
||||
@@ -20,13 +20,14 @@ namespace xgboost::collective {
|
||||
class Loop {
|
||||
public:
|
||||
struct Op {
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1 } code;
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code;
|
||||
std::int32_t rank{-1};
|
||||
std::int8_t* ptr{nullptr};
|
||||
std::size_t n{0};
|
||||
TCPSocket* sock{nullptr};
|
||||
std::size_t off{0};
|
||||
|
||||
explicit Op(Code c) : code{c} { CHECK(c == kBlock); }
|
||||
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
|
||||
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
|
||||
Op(Op const&) = default;
|
||||
@@ -41,12 +42,15 @@ class Loop {
|
||||
std::mutex mu_;
|
||||
std::queue<Op> queue_;
|
||||
std::chrono::seconds timeout_;
|
||||
|
||||
Result rc_;
|
||||
std::mutex rc_lock_; // lock for transferring error info.
|
||||
|
||||
bool stop_{false};
|
||||
std::exception_ptr curr_exce_{nullptr};
|
||||
common::Monitor timer_;
|
||||
common::Monitor mutable timer_;
|
||||
|
||||
Result EmptyQueue();
|
||||
Result EmptyQueue(std::queue<Op>* p_queue) const;
|
||||
void Process();
|
||||
|
||||
public:
|
||||
@@ -60,15 +64,7 @@ class Loop {
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Block() {
|
||||
{
|
||||
std::unique_lock lock{mu_};
|
||||
cv_.notify_all();
|
||||
}
|
||||
std::unique_lock lock{mu_};
|
||||
cv_.wait(lock, [this] { return this->queue_.empty() || stop_; });
|
||||
return std::move(rc_);
|
||||
}
|
||||
[[nodiscard]] Result Block();
|
||||
|
||||
explicit Loop(std::chrono::seconds timeout);
|
||||
|
||||
|
||||
@@ -2,12 +2,14 @@
|
||||
* Copyright 2023 XGBoost contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
|
||||
#include "comm.cuh"
|
||||
#include "nccl_device_communicator.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync)
|
||||
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync,
|
||||
StringView nccl_path)
|
||||
: device_ordinal_{device_ordinal},
|
||||
needs_sync_{needs_sync},
|
||||
world_size_{GetWorldSize()},
|
||||
@@ -18,6 +20,7 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sy
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
stub_ = std::make_shared<NcclStub>(std::move(nccl_path));
|
||||
|
||||
std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
|
||||
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
|
||||
@@ -43,7 +46,8 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sy
|
||||
|
||||
nccl_unique_id_ = GetUniqueId();
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
|
||||
auto rc = stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
|
||||
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
|
||||
@@ -51,7 +55,8 @@ NcclDeviceCommunicator::~NcclDeviceCommunicator() {
|
||||
return;
|
||||
}
|
||||
if (nccl_comm_) {
|
||||
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
|
||||
auto rc = stub_->CommDestroy(nccl_comm_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
|
||||
LOG(CONSOLE) << "======== NCCL Statistics========";
|
||||
@@ -137,8 +142,9 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
|
||||
auto *device_buffer = buffer.data().get();
|
||||
|
||||
// First gather data from all the workers.
|
||||
dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
|
||||
nccl_comm_, dh::DefaultStream()));
|
||||
auto rc = stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
|
||||
nccl_comm_, dh::DefaultStream());
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
if (needs_sync_) {
|
||||
dh::DefaultStream().Sync();
|
||||
}
|
||||
@@ -170,9 +176,10 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
|
||||
if (IsBitwiseOp(op)) {
|
||||
BitwiseAllReduce(send_receive_buffer, count, data_type, op);
|
||||
} else {
|
||||
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
|
||||
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
|
||||
dh::DefaultStream()));
|
||||
auto rc = stub_->Allreduce(send_receive_buffer, send_receive_buffer, count,
|
||||
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
|
||||
dh::DefaultStream());
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
allreduce_bytes_ += count * GetTypeSize(data_type);
|
||||
allreduce_calls_ += 1;
|
||||
@@ -185,8 +192,9 @@ void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_bu
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclAllGather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
|
||||
dh::DefaultStream()));
|
||||
auto rc = stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
|
||||
dh::DefaultStream());
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
|
||||
@@ -206,14 +214,18 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
|
||||
receive_buffer->resize(total_bytes);
|
||||
|
||||
size_t offset = 0;
|
||||
dh::safe_nccl(ncclGroupStart());
|
||||
for (int32_t i = 0; i < world_size_; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
|
||||
ncclChar, i, nccl_comm_, dh::DefaultStream()));
|
||||
offset += as_bytes;
|
||||
}
|
||||
dh::safe_nccl(ncclGroupEnd());
|
||||
auto rc = Success() << [&] { return stub_->GroupStart(); } << [&] {
|
||||
for (int32_t i = 0; i < world_size_; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
auto rc = stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
|
||||
ncclChar, i, nccl_comm_, dh::DefaultStream());
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
offset += as_bytes;
|
||||
}
|
||||
return Success();
|
||||
} << [&] { return stub_->GroupEnd(); };
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::Synchronize() {
|
||||
|
||||
@@ -4,8 +4,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "comm.cuh"
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
#include "nccl_stub.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
@@ -25,7 +27,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
* needed. The in-memory communicator is used in tests with multiple threads, each thread
|
||||
* representing a rank/worker, so the additional synchronization is needed to avoid deadlocks.
|
||||
*/
|
||||
explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync);
|
||||
explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync, StringView nccl_path);
|
||||
~NcclDeviceCommunicator() override;
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override;
|
||||
@@ -74,7 +76,8 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
static const int kRootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (rank_ == kRootRank) {
|
||||
dh::safe_nccl(ncclGetUniqueId(&id));
|
||||
auto rc = stub_->GetUniqueId(&id);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
|
||||
return id;
|
||||
@@ -88,6 +91,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
ncclComm_t nccl_comm_{};
|
||||
std::shared_ptr<NcclStub> stub_;
|
||||
ncclUniqueId nccl_unique_id_{};
|
||||
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
|
||||
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
|
||||
|
||||
131
src/collective/nccl_stub.cc
Normal file
131
src/collective/nccl_stub.cc
Normal file
@@ -0,0 +1,131 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL) || (defined(XGBOOST_USE_RCCL) && 0)
|
||||
#include "nccl_stub.h"
|
||||
|
||||
#include <cuda.h> // for CUDA_VERSION
|
||||
#include <cuda_runtime_api.h> // for cudaPeekAtLastError
|
||||
#include <dlfcn.h> // for dlclose, dlsym, dlopen
|
||||
#include <nccl.h>
|
||||
#include <thrust/system/cuda/error.h> // for cuda_category
|
||||
#include <thrust/system_error.h> // for system_error
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <sstream> // for stringstream
|
||||
#include <string> // for string
|
||||
#include <utility> // for move
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
 * @brief Translate an NCCL status code into a Result.
 *
 * @param code Status returned by an NCCL call.
 *
 * @return Success for ncclSuccess, otherwise a failure carrying the NCCL error string plus
 *         extra detail for CUDA and system errors.
 */
Result NcclStub::GetNcclResult(ncclResult_t code) const {
  if (code == ncclSuccess) {
    return Success();
  }

  std::stringstream msg;
  msg << "NCCL failure: " << this->GetErrorString(code) << ".";
  switch (code) {
    case ncclUnhandledCudaError: {
      // nccl usually preserves the last error so we can get more details.
      auto last = cudaPeekAtLastError();
      msg << " CUDA error: " << thrust::system_error(last, thrust::cuda_category()).what() << "\n";
      break;
    }
    case ncclSystemError: {
      msg << " This might be caused by a network configuration issue. Please consider specifying "
             "the network interface for NCCL via environment variables listed in its reference: "
             "`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
      break;
    }
    default:
      break;
  }
  return Fail(msg.str());
}
|
||||
|
||||
/**
 * @brief Bind the NCCL entry points used by XGBoost.
 *
 * When dynamic loading is compiled in, the library at `path` is opened with dlopen and each
 * symbol is resolved with dlsym; a failure is reported through LOG(FATAL) together with a
 * troubleshooting hint. Otherwise the function pointers are bound to the linked NCCL symbols.
 *
 * @param path Location of the NCCL shared object; must be non-empty when dynamic loading is
 *             compiled in.
 */
NcclStub::NcclStub(StringView path) : path_{std::move(path)} {
#if defined(XGBOOST_USE_DLOPEN_NCCL) || defined(XGBOOST_USE_DLOPEN_RCCL)
  CHECK(!path_.empty()) << "Empty path for NCCL.";

  // Troubleshooting hint appended to every load failure below.
  auto cu_major = (CUDA_VERSION) / 1000;
  std::stringstream hint;
  hint << R"m(

If XGBoost is installed from PyPI with pip, the error can fixed by:

- Run `pip install nvidia-nccl-cu)m"
       << cu_major << "` (Or with any CUDA version that's compatible with " << cu_major << ").";
  hint << R"m(

Otherwise, please refer to:

https://xgboost.readthedocs.io/en/stable/tutorials/dask.html#troubleshooting

for more info, or open an issue on GitHub. Starting from XGBoost 2.1.0, the PyPI package
no long bundles NCCL in the binary wheel.

)m";
  auto help = hint.str();
  std::string msg{"Failed to load NCCL from path: `" + path_ + "`. Error:\n "};

  // Resolve one symbol and cast it to the type of the member it will be stored in.
  auto safe_load = [&](auto t, StringView name) {
    auto ptr = reinterpret_cast<decltype(t)>(dlsym(handle_, name.c_str()));
    if (ptr == nullptr) {
      std::stringstream errs;
      errs << "Failed to load NCCL symbol `" << name << "` from " << path_ << ". Error:\n "
           << dlerror() << help;
      LOG(FATAL) << errs.str();
    }
    return ptr;
  };

  handle_ = dlopen(path_.c_str(), RTLD_LAZY);
  if (handle_ == nullptr) {
    LOG(FATAL) << msg << dlerror() << help;
  }

  allreduce_ = safe_load(allreduce_, "ncclAllReduce");
  broadcast_ = safe_load(broadcast_, "ncclBroadcast");
  allgather_ = safe_load(allgather_, "ncclAllGather");
  comm_init_rank_ = safe_load(comm_init_rank_, "ncclCommInitRank");
  comm_destroy_ = safe_load(comm_destroy_, "ncclCommDestroy");
  get_uniqueid_ = safe_load(get_uniqueid_, "ncclGetUniqueId");
  send_ = safe_load(send_, "ncclSend");
  recv_ = safe_load(recv_, "ncclRecv");
  group_start_ = safe_load(group_start_, "ncclGroupStart");
  group_end_ = safe_load(group_end_, "ncclGroupEnd");
  get_error_string_ = safe_load(get_error_string_, "ncclGetErrorString");
  get_version_ = safe_load(get_version_, "ncclGetVersion");

  // The version is decoded as major * 10000 + minor * 100 + patch.
  std::int32_t version{0};
  CHECK_EQ(get_version_(&version), ncclSuccess);
  auto major = version / 10000;
  auto minor = (version / 100) % 100;
  auto patch = version % 100;

  LOG(INFO) << "Loaded shared NCCL " << major << "." << minor << "." << patch << ":`" << path_
            << "`" << std::endl;
#else
  allreduce_ = ncclAllReduce;
  broadcast_ = ncclBroadcast;
  allgather_ = ncclAllGather;
  comm_init_rank_ = ncclCommInitRank;
  comm_destroy_ = ncclCommDestroy;
  get_uniqueid_ = ncclGetUniqueId;
  send_ = ncclSend;
  recv_ = ncclRecv;
  group_start_ = ncclGroupStart;
  group_end_ = ncclGroupEnd;
  get_error_string_ = ncclGetErrorString;
  get_version_ = ncclGetVersion;
#endif
}
|
||||
|
||||
// Close the dynamically loaded library, if any; a failed close is only logged.
NcclStub::~NcclStub() {  // NOLINT
#if defined(XGBOOST_USE_DLOPEN_NCCL) || defined(XGBOOST_USE_DLOPEN_RCCL)
  if (handle_ != nullptr && dlclose(handle_) != 0) {
    LOG(WARNING) << "Failed to close NCCL handle:" << dlerror();
  }
  handle_ = nullptr;
#endif  // defined(XGBOOST_USE_DLOPEN_NCCL)
}
|
||||
} // namespace xgboost::collective
|
||||
#endif // defined(XGBOOST_USE_NCCL)
|
||||
86
src/collective/nccl_stub.h
Normal file
86
src/collective/nccl_stub.h
Normal file
@@ -0,0 +1,86 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#if defined(XGBOOST_USE_NCCL) || (defined(XGBOOST_USE_RCCL) && 0)
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <nccl.h>
|
||||
|
||||
#include <string> // for string
|
||||
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/string_view.h" // for StringView
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
* @brief A stub for NCCL to facilitate dynamic loading.
|
||||
*/
|
||||
class NcclStub {
|
||||
#if defined(XGBOOST_USE_DLOPEN_NCCL) || defined(XGBOOST_USE_DLOPEN_RCCL)
|
||||
void* handle_{nullptr};
|
||||
#endif // defined(XGBOOST_USE_DLOPEN_NCCL)
|
||||
std::string path_;
|
||||
|
||||
decltype(ncclAllReduce)* allreduce_{nullptr};
|
||||
decltype(ncclBroadcast)* broadcast_{nullptr};
|
||||
decltype(ncclAllGather)* allgather_{nullptr};
|
||||
decltype(ncclCommInitRank)* comm_init_rank_{nullptr};
|
||||
decltype(ncclCommDestroy)* comm_destroy_{nullptr};
|
||||
decltype(ncclGetUniqueId)* get_uniqueid_{nullptr};
|
||||
decltype(ncclSend)* send_{nullptr};
|
||||
decltype(ncclRecv)* recv_{nullptr};
|
||||
decltype(ncclGroupStart)* group_start_{nullptr};
|
||||
decltype(ncclGroupEnd)* group_end_{nullptr};
|
||||
decltype(ncclGetErrorString)* get_error_string_{nullptr};
|
||||
decltype(ncclGetVersion)* get_version_{nullptr};
|
||||
|
||||
public:
|
||||
Result GetNcclResult(ncclResult_t code) const;
|
||||
|
||||
public:
|
||||
explicit NcclStub(StringView path);
|
||||
~NcclStub();
|
||||
|
||||
[[nodiscard]] Result Allreduce(const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
cudaStream_t stream) const {
|
||||
return this->GetNcclResult(allreduce_(sendbuff, recvbuff, count, datatype, op, comm, stream));
|
||||
}
|
||||
[[nodiscard]] Result Broadcast(const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, int root, ncclComm_t comm,
|
||||
cudaStream_t stream) const {
|
||||
return this->GetNcclResult(broadcast_(sendbuff, recvbuff, count, datatype, root, comm, stream));
|
||||
}
|
||||
[[nodiscard]] Result Allgather(const void* sendbuff, void* recvbuff, size_t sendcount,
|
||||
ncclDataType_t datatype, ncclComm_t comm,
|
||||
cudaStream_t stream) const {
|
||||
return this->GetNcclResult(allgather_(sendbuff, recvbuff, sendcount, datatype, comm, stream));
|
||||
}
|
||||
[[nodiscard]] Result CommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId,
|
||||
int rank) const {
|
||||
return this->GetNcclResult(this->comm_init_rank_(comm, nranks, commId, rank));
|
||||
}
|
||||
[[nodiscard]] Result CommDestroy(ncclComm_t comm) const {
|
||||
return this->GetNcclResult(comm_destroy_(comm));
|
||||
}
|
||||
[[nodiscard]] Result GetUniqueId(ncclUniqueId* uniqueId) const {
|
||||
return this->GetNcclResult(get_uniqueid_(uniqueId));
|
||||
}
|
||||
[[nodiscard]] Result Send(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer,
|
||||
ncclComm_t comm, cudaStream_t stream) {
|
||||
return this->GetNcclResult(send_(sendbuff, count, datatype, peer, comm, stream));
|
||||
}
|
||||
[[nodiscard]] Result Recv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer,
|
||||
ncclComm_t comm, cudaStream_t stream) const {
|
||||
return this->GetNcclResult(recv_(recvbuff, count, datatype, peer, comm, stream));
|
||||
}
|
||||
[[nodiscard]] Result GroupStart() const { return this->GetNcclResult(group_start_()); }
|
||||
[[nodiscard]] Result GroupEnd() const { return this->GetNcclResult(group_end_()); }
|
||||
[[nodiscard]] const char* GetErrorString(ncclResult_t result) const {
|
||||
return get_error_string_(result);
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
|
||||
#endif // defined(XGBOOST_USE_NCCL)
|
||||
@@ -58,36 +58,35 @@ Result Tracker::WaitUntilReady() const {
|
||||
|
||||
RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr)
|
||||
: sock_{std::move(sock)} {
|
||||
auto host = addr.Addr();
|
||||
|
||||
std::int32_t rank{0};
|
||||
rc_ = Success()
|
||||
<< [&] { return proto::Magic{}.Verify(&sock_); }
|
||||
<< [&] { return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_); };
|
||||
if (!rc_.OK()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::string cmd;
|
||||
sock_.Recv(&cmd);
|
||||
auto jcmd = Json::Load(StringView{cmd});
|
||||
cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
|
||||
Json jcmd;
|
||||
std::int32_t port{0};
|
||||
if (cmd_ == proto::CMD::kStart) {
|
||||
proto::Start start;
|
||||
rc_ = start.TrackerHandle(jcmd, &world_, world, &port, &sock_, &eport_);
|
||||
} else if (cmd_ == proto::CMD::kPrint) {
|
||||
proto::Print print;
|
||||
rc_ = print.TrackerHandle(jcmd, &msg_);
|
||||
} else if (cmd_ == proto::CMD::kError) {
|
||||
proto::ErrorCMD error;
|
||||
rc_ = error.TrackerHandle(jcmd, &msg_, &code_);
|
||||
}
|
||||
if (!rc_.OK()) {
|
||||
return;
|
||||
}
|
||||
|
||||
info_ = proto::PeerInfo{host, port, rank};
|
||||
rc_ = Success() << [&] { return proto::Magic{}.Verify(&sock_); } << [&] {
|
||||
return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_);
|
||||
} << [&] {
|
||||
std::string cmd;
|
||||
sock_.Recv(&cmd);
|
||||
jcmd = Json::Load(StringView{cmd});
|
||||
cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
|
||||
return Success();
|
||||
} << [&] {
|
||||
if (cmd_ == proto::CMD::kStart) {
|
||||
proto::Start start;
|
||||
return start.TrackerHandle(jcmd, &world_, world, &port, &sock_, &eport_);
|
||||
} else if (cmd_ == proto::CMD::kPrint) {
|
||||
proto::Print print;
|
||||
return print.TrackerHandle(jcmd, &msg_);
|
||||
} else if (cmd_ == proto::CMD::kError) {
|
||||
proto::ErrorCMD error;
|
||||
return error.TrackerHandle(jcmd, &msg_, &code_);
|
||||
}
|
||||
return Success();
|
||||
} << [&] {
|
||||
auto host = addr.Addr();
|
||||
info_ = proto::PeerInfo{host, port, rank};
|
||||
return Success();
|
||||
};
|
||||
}
|
||||
|
||||
RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
|
||||
@@ -137,15 +136,18 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
|
||||
std::int32_t n_shutdown{0};
|
||||
bool during_restart{false};
|
||||
bool running{false};
|
||||
std::vector<WorkerProxy> pending;
|
||||
|
||||
explicit State(std::int32_t world) : n_workers{world} {}
|
||||
State(State const& that) = delete;
|
||||
State& operator=(State&& that) = delete;
|
||||
|
||||
// modifiers
|
||||
void Start(WorkerProxy&& worker) {
|
||||
CHECK_LT(pending.size(), n_workers);
|
||||
CHECK_LE(n_shutdown, n_workers);
|
||||
CHECK(!running);
|
||||
|
||||
pending.emplace_back(std::forward<WorkerProxy>(worker));
|
||||
|
||||
@@ -155,6 +157,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
CHECK_GE(n_shutdown, 0);
|
||||
CHECK_LT(n_shutdown, n_workers);
|
||||
|
||||
running = false;
|
||||
++n_shutdown;
|
||||
|
||||
CHECK_LE(n_shutdown, n_workers);
|
||||
@@ -163,21 +166,26 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
CHECK_LE(pending.size(), n_workers);
|
||||
CHECK_LE(n_shutdown, n_workers);
|
||||
|
||||
running = false;
|
||||
during_restart = true;
|
||||
}
|
||||
[[nodiscard]] bool Ready() const {
|
||||
CHECK_LE(pending.size(), n_workers);
|
||||
return static_cast<std::int32_t>(pending.size()) == n_workers;
|
||||
}
|
||||
void Bootstrap() {
|
||||
CHECK_EQ(pending.size(), n_workers);
|
||||
CHECK_LE(n_shutdown, n_workers);
|
||||
|
||||
running = true;
|
||||
|
||||
// A reset.
|
||||
n_shutdown = 0;
|
||||
during_restart = false;
|
||||
pending.clear();
|
||||
}
|
||||
|
||||
// observers
|
||||
[[nodiscard]] bool Ready() const {
|
||||
CHECK_LE(pending.size(), n_workers);
|
||||
return static_cast<std::int32_t>(pending.size()) == n_workers;
|
||||
}
|
||||
[[nodiscard]] bool ShouldContinue() const {
|
||||
CHECK_LE(pending.size(), n_workers);
|
||||
CHECK_LE(n_shutdown, n_workers);
|
||||
@@ -187,7 +195,31 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
}
|
||||
};
|
||||
|
||||
return std::async(std::launch::async, [this] {
|
||||
// Logs the error reported by `worker`, then opens a connection to every entry in
// `worker_error_handles_` except those on the reporting worker's host.
//
// Returns Success() when every connection attempt succeeds, or a failure result
// as soon as one attempt fails (remaining entries are then not contacted).
auto handle_error = [&](WorkerProxy const& worker) {
|
||||
auto msg = worker.Msg();
|
||||
auto code = worker.Code();
|
||||
LOG(WARNING) << "Recieved error from [" << worker.Host() << ":" << worker.Rank() << "]: " << msg
|
||||
<< " code:" << code;
|
||||
auto host = worker.Host();
|
||||
// We signal all workers for the error, if they haven't aborted already.
|
||||
// Each entry is a (string, int32) pair; presumably (host, error port) - confirm
// against where `worker_error_handles_` is filled.
for (auto& w : worker_error_handles_) {
|
||||
// Skip the worker that reported the error.
// NOTE(review): the match is on host only, so any other handle with the same
// host is skipped as well - confirm this is intended.
if (w.first == host) {
|
||||
continue;
|
||||
}
|
||||
TCPSocket out;
|
||||
// Connecting to the error port as a signal for exit.
|
||||
//
|
||||
// retry is set to 1, just let the worker timeout or error. Otherwise the
|
||||
// tracker and the worker might be waiting for each other.
|
||||
auto rc = Connect(w.first, w.second, 1, timeout_, &out);
|
||||
if (!rc.OK()) {
|
||||
// NOTE(review): `rc` is dropped here, so the returned failure does not carry
// the underlying connect error; the caller chains its own result with
// Fail(msg, std::move(rc)).
return Fail("Failed to inform workers to stop.");
|
||||
}
|
||||
}
|
||||
return Success();
|
||||
};
|
||||
|
||||
return std::async(std::launch::async, [this, handle_error] {
|
||||
State state{this->n_workers_};
|
||||
|
||||
while (state.ShouldContinue()) {
|
||||
@@ -205,6 +237,16 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
}
|
||||
switch (worker.Command()) {
|
||||
case proto::CMD::kStart: {
|
||||
if (state.running) {
|
||||
// Something went wrong with one of the workers. It got disconnected without
|
||||
// notice.
|
||||
state.Error();
|
||||
rc = handle_error(worker);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to handle abort.", std::move(rc));
|
||||
}
|
||||
}
|
||||
|
||||
state.Start(std::move(worker));
|
||||
if (state.Ready()) {
|
||||
rc = this->Bootstrap(&state.pending);
|
||||
@@ -216,36 +258,20 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
continue;
|
||||
}
|
||||
case proto::CMD::kShutdown: {
|
||||
if (state.during_restart) {
|
||||
// The worker can still send shutdown after call to `std::exit`.
|
||||
continue;
|
||||
}
|
||||
state.Shutdown();
|
||||
continue;
|
||||
}
|
||||
case proto::CMD::kError: {
|
||||
if (state.during_restart) {
|
||||
// Ignore further errors.
|
||||
continue;
|
||||
}
|
||||
state.Error();
|
||||
auto msg = worker.Msg();
|
||||
auto code = worker.Code();
|
||||
LOG(WARNING) << "Recieved error from [" << worker.Host() << ":" << worker.Rank()
|
||||
<< "]: " << msg << " code:" << code;
|
||||
auto host = worker.Host();
|
||||
// We signal all workers for the error, if they haven't aborted already.
|
||||
for (auto& w : worker_error_handles_) {
|
||||
if (w.first == host) {
|
||||
continue;
|
||||
}
|
||||
TCPSocket out;
|
||||
// retry is set to 1, just let the worker timeout or error. Otherwise the
|
||||
// tracker and the worker might be waiting for each other.
|
||||
auto rc = Connect(w.first, w.second, 1, timeout_, &out);
|
||||
// send signal to stop the worker.
|
||||
proto::ShutdownCMD shutdown;
|
||||
rc = shutdown.Send(&out);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to inform workers to stop.");
|
||||
}
|
||||
}
|
||||
|
||||
rc = handle_error(worker);
|
||||
continue;
|
||||
}
|
||||
case proto::CMD::kPrint: {
|
||||
|
||||
@@ -114,6 +114,9 @@ class RabitTracker : public Tracker {
|
||||
// record for how to reach out to workers if error happens.
|
||||
std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
|
||||
// listening socket for incoming workers.
|
||||
//
|
||||
// At the moment, the listener calls accept without first polling. We can add an
|
||||
// additional unix domain socket to allow cancelling the accept.
|
||||
TCPSocket listener_;
|
||||
|
||||
Result Bootstrap(std::vector<WorkerProxy>* p_workers);
|
||||
|
||||
Reference in New Issue
Block a user