merge latest changes

This commit is contained in:
Hui Liu
2023-12-13 21:06:28 -08:00
194 changed files with 4859 additions and 2838 deletions

View File

@@ -26,18 +26,19 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
}
for (std::int32_t r = 0; r < world; ++r) {
auto send_rank = (rank + world - r + worker_off) % world;
auto send_off = send_rank * segment_size;
send_off = std::min(send_off, data.size_bytes());
auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off));
next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
auto recv_rank = (rank + world - r - 1 + worker_off) % world;
auto recv_off = recv_rank * segment_size;
recv_off = std::min(recv_off, data.size_bytes());
auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off));
prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
auto rc = prev_ch->Block();
auto rc = Success() << [&] {
auto send_rank = (rank + world - r + worker_off) % world;
auto send_off = send_rank * segment_size;
send_off = std::min(send_off, data.size_bytes());
auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off));
return next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
} << [&] {
auto recv_rank = (rank + world - r - 1 + worker_off) % world;
auto recv_off = recv_rank * segment_size;
recv_off = std::min(recv_off, data.size_bytes());
auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off));
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
} << [&] { return prev_ch->Block(); };
if (!rc.OK()) {
return rc;
}
@@ -78,19 +79,19 @@ namespace detail {
auto next_ch = comm.Chan(next);
for (std::int32_t r = 0; r < world; ++r) {
auto send_rank = (rank + world - r) % world;
auto send_off = offset[send_rank];
auto send_size = sizes[send_rank];
auto send_seg = erased_result.subspan(send_off, send_size);
next_ch->SendAll(send_seg);
auto recv_rank = (rank + world - r - 1) % world;
auto recv_off = offset[recv_rank];
auto recv_size = sizes[recv_rank];
auto recv_seg = erased_result.subspan(recv_off, recv_size);
prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
auto rc = prev_ch->Block();
auto rc = Success() << [&] {
auto send_rank = (rank + world - r) % world;
auto send_off = offset[send_rank];
auto send_size = sizes[send_rank];
auto send_seg = erased_result.subspan(send_off, send_size);
return next_ch->SendAll(send_seg);
} << [&] {
auto recv_rank = (rank + world - r - 1) % world;
auto recv_off = offset[recv_rank];
auto recv_size = sizes[recv_rank];
auto recv_seg = erased_result.subspan(recv_off, recv_size);
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
} << [&] { return prev_ch->Block(); };
if (!rc.OK()) {
return rc;
}

View File

@@ -6,6 +6,7 @@
#include <algorithm> // for min
#include <cstddef> // for size_t
#include <cstdint> // for int32_t, int8_t
#include <utility> // for move
#include <vector> // for vector
#include "../data/array_interface.h" // for Type, DispatchDType
@@ -36,7 +37,10 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg);
auto send_seg = data.subspan(send_off, seg_nbytes);
next_ch->SendAll(send_seg);
auto rc = next_ch->SendAll(send_seg);
if (!rc.OK()) {
return rc;
}
// receive from ring prev
auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg;
@@ -46,8 +50,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
auto recv_seg = data.subspan(recv_off, seg_nbytes);
auto seg = s_buf.subspan(0, recv_seg.size());
prev_ch->RecvAll(seg);
auto rc = prev_ch->Block();
rc = std::move(rc) << [&] { return prev_ch->RecvAll(seg); } << [&] { return comm.Block(); };
if (!rc.OK()) {
return rc;
}
@@ -62,6 +65,9 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const& op,
ArrayInterfaceHandler::Type type) {
if (comm.World() == 1) {
return Success();
}
return DispatchDType(type, [&](auto t) {
using T = decltype(t);
// Divide the data into segments according to the number of workers.
@@ -80,11 +86,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
auto prev_ch = comm.Chan(prev);
auto next_ch = comm.Chan(next);
rc = RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
if (!rc.OK()) {
return rc;
}
return comm.Block();
return std::move(rc) << [&] {
return RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
} << [&] { return comm.Block(); };
});
}
} // namespace xgboost::collective::cpu_impl

View File

@@ -62,8 +62,8 @@ Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t
if (shifted_rank != 0) { // not root
auto parent = ShiftRight(ShiftedParentRank(shifted_rank, depth), world, root);
comm.Chan(parent)->RecvAll(data);
auto rc = comm.Chan(parent)->Block();
auto rc = Success() << [&] { return comm.Chan(parent)->RecvAll(data); }
<< [&] { return comm.Chan(parent)->Block(); };
if (!rc.OK()) {
return Fail("broadcast failed.", std::move(rc));
}
@@ -75,7 +75,10 @@ Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t
auto sft_peer = shifted_rank + (1 << i);
auto peer = ShiftRight(sft_peer, world, root);
CHECK_NE(peer, root);
comm.Chan(peer)->SendAll(data);
auto rc = comm.Chan(peer)->SendAll(data);
if (!rc.OK()) {
return rc;
}
}
}

View File

@@ -23,25 +23,6 @@ Coll* Coll::MakeCUDAVar() { return new NCCLColl{}; }
NCCLColl::~NCCLColl() = default;
namespace {
Result GetNCCLResult(ncclResult_t code) {
if (code == ncclSuccess) {
return Success();
}
std::stringstream ss;
ss << "NCCL failure: " << ncclGetErrorString(code) << ".";
if (code == ncclUnhandledCudaError) {
// NCCL usually preserves the last error so we can get more details.
auto err = cudaPeekAtLastError();
ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n";
} else if (code == ncclSystemError) {
ss << " This might be caused by a network configuration issue. Please consider specifying "
"the network interface for NCCL via environment variables listed in its reference: "
"`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
}
return Fail(ss.str());
}
auto GetNCCLType(ArrayInterfaceHandler::Type type) {
auto fatal = [] {
LOG(FATAL) << "Invalid type for NCCL operation.";
@@ -98,11 +79,12 @@ void RunBitwiseAllreduce(dh::CUDAStreamView stream, common::Span<std::int8_t> ou
common::Span<std::int8_t> data, Op op) {
dh::device_vector<std::int8_t> buffer(data.size() * pcomm->World());
auto* device_buffer = buffer.data().get();
auto stub = pcomm->Stub();
// First gather data from all the workers.
CHECK(handle);
auto rc = GetNCCLResult(
ncclAllGather(data.data(), device_buffer, data.size(), ncclInt8, handle, pcomm->Stream()));
auto rc =
stub->Allgather(data.data(), device_buffer, data.size(), ncclInt8, handle, pcomm->Stream());
if (!rc.OK()) {
return rc;
}
@@ -153,6 +135,8 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
}
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
CHECK(nccl);
auto stub = nccl->Stub();
return Success() << [&] {
if (IsBitwiseOp(op)) {
return BitwiseAllReduce(nccl, nccl->Handle(), data, op);
@@ -160,9 +144,8 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
return DispatchDType(type, [=](auto t) {
using T = decltype(t);
auto rdata = common::RestoreType<T>(data);
auto rc = ncclAllReduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
GetNCCLRedOp(op), nccl->Handle(), nccl->Stream());
return GetNCCLResult(rc);
return stub->Allreduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
GetNCCLRedOp(op), nccl->Handle(), nccl->Stream());
});
}
} << [&] { return nccl->Block(); };
@@ -175,9 +158,11 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
}
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
CHECK(nccl);
auto stub = nccl->Stub();
return Success() << [&] {
return GetNCCLResult(ncclBroadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root,
nccl->Handle(), nccl->Stream()));
return stub->Broadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root,
nccl->Handle(), nccl->Stream());
} << [&] { return nccl->Block(); };
}
@@ -188,10 +173,12 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
}
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
CHECK(nccl);
auto stub = nccl->Stub();
auto send = data.subspan(comm.Rank() * size, size);
return Success() << [&] {
return GetNCCLResult(
ncclAllGather(send.data(), data.data(), size, ncclInt8, nccl->Handle(), nccl->Stream()));
return stub->Allgather(send.data(), data.data(), size, ncclInt8, nccl->Handle(),
nccl->Stream());
} << [&] { return nccl->Block(); };
}
@@ -203,19 +190,20 @@ namespace cuda_impl {
*/
Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes, common::Span<std::int8_t> recv) {
return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] {
auto stub = comm->Stub();
return Success() << [&stub] { return stub->GroupStart(); } << [&] {
std::size_t offset = 0;
for (std::int32_t r = 0; r < comm->World(); ++r) {
auto as_bytes = sizes[r];
auto rc = ncclBroadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
ncclInt8, r, comm->Handle(), dh::DefaultStream());
if (rc != ncclSuccess) {
return GetNCCLResult(rc);
auto rc = stub->Broadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
ncclInt8, r, comm->Handle(), dh::DefaultStream());
if (!rc.OK()) {
return rc;
}
offset += as_bytes;
}
return Success();
} << [] { return GetNCCLResult(ncclGroupEnd()); };
} << [&] { return stub->GroupEnd(); };
}
} // namespace cuda_impl
@@ -228,10 +216,11 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const>
if (!comm.IsDistributed()) {
return Success();
}
auto stub = nccl->Stub();
switch (algo) {
case AllgatherVAlgo::kRing: {
return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] {
return Success() << [&] { return stub->GroupStart(); } << [&] {
// get worker offset
detail::AllgatherVOffset(sizes, recv_segments);
// copy data
@@ -241,8 +230,8 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const>
cudaMemcpyDeviceToDevice, nccl->Stream()));
}
return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
} << [] {
return GetNCCLResult(ncclGroupEnd());
} << [&] {
return stub->GroupEnd();
} << [&] { return nccl->Block(); };
}
case AllgatherVAlgo::kBcast: {

View File

@@ -8,7 +8,8 @@
#include "../data/array_interface.h" // for ArrayInterfaceHandler
#include "coll.h" // for Coll
#include "comm.h" // for Comm
#include "xgboost/span.h" // for Span
#include "nccl_stub.h"
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
class NCCLColl : public Coll {

View File

@@ -5,6 +5,7 @@
#include <algorithm> // for copy
#include <chrono> // for seconds
#include <cstdlib> // for exit
#include <memory> // for shared_ptr
#include <string> // for string
#include <utility> // for move, forward
@@ -29,19 +30,28 @@ Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds time
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
std::string const& task_id, TCPSocket* out, std::int32_t rank,
std::int32_t world) {
// get information from tracker
// Get information from the tracker
CHECK(!info.host.empty());
auto rc = Connect(info.host, info.port, retry, timeout, out);
if (!rc.OK()) {
return Fail("Failed to connect to the tracker.", std::move(rc));
}
TCPSocket& tracker = *out;
return std::move(rc)
<< [&] { return tracker.NonBlocking(false); }
<< [&] { return tracker.RecvTimeout(timeout); }
<< [&] { return proto::Magic{}.Verify(&tracker); }
<< [&] { return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id); };
return Success() << [&] {
auto rc = Connect(info.host, info.port, retry, timeout, out);
if (rc.OK()) {
return rc;
} else {
return Fail("Failed to connect to the tracker.", std::move(rc));
}
} << [&] {
return tracker.NonBlocking(false);
} << [&] {
return tracker.RecvTimeout(timeout);
} << [&] {
return proto::Magic{}.Verify(&tracker);
} << [&] {
return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id);
} << [&] {
LOG(INFO) << "Task " << task_id << " connected to the tracker";
return Success();
};
}
[[nodiscard]] Result Comm::ConnectTracker(TCPSocket* out) const {
@@ -49,14 +59,6 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
this->Rank(), this->World());
}
#if !defined(XGBOOST_USE_NCCL) && !defined(XGBOOST_USE_RCCL)
Comm* Comm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
common::AssertGPUSupport();
common::AssertNCCLSupport();
return nullptr;
}
#endif // !defined(XGBOOST_USE_NCCL)
[[nodiscard]] Result ConnectWorkers(Comm const& comm, TCPSocket* listener, std::int32_t lport,
proto::PeerInfo ninfo, std::chrono::seconds timeout,
std::int32_t retry,
@@ -181,12 +183,21 @@ Comm* Comm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
}
RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t retry, std::string task_id)
: Comm{std::move(host), port, timeout, retry, std::move(task_id)} {
std::int32_t retry, std::string task_id, StringView nccl_path)
: HostComm{std::move(host), port, timeout, retry, std::move(task_id)},
nccl_path_{std::move(nccl_path)} {
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
CHECK(rc.OK()) << rc.Report();
}
#if !defined(XGBOOST_USE_NCCL) && !defined(XGBOOST_USE_RCCL)
Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
common::AssertGPUSupport();
common::AssertNCCLSupport();
return nullptr;
}
#endif // !defined(XGBOOST_USE_NCCL)
[[nodiscard]] Result RabitComm::Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
std::string task_id) {
TCPSocket tracker;
@@ -209,24 +220,18 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
std::shared_ptr<TCPSocket> error_sock{TCPSocket::CreatePtr(domain)};
auto eport = error_sock->BindHost();
error_sock->Listen();
error_worker_ = std::thread{[this, error_sock = std::move(error_sock)] {
error_worker_ = std::thread{[error_sock = std::move(error_sock)] {
auto conn = error_sock->Accept();
// On Windows accept returns an invalid socket after network is shutdown.
// On Windows, accept returns a closed socket after finalize.
if (conn.IsClosed()) {
return;
}
LOG(WARNING) << "Another worker is running into error.";
std::string scmd;
conn.Recv(&scmd);
auto jcmd = Json::Load(scmd);
auto rc = this->Shutdown();
if (!rc.OK()) {
LOG(WARNING) << "Fail to shutdown worker:" << rc.Report();
}
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
exit(-1);
// exit is nicer than abort as the former performs cleanups.
std::exit(-1);
#else
LOG(FATAL) << rc.Report();
LOG(FATAL) << "abort";
#endif
}};
error_worker_.detach();
@@ -259,8 +264,8 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
CHECK(this->channels_.empty());
for (auto& w : workers) {
if (w) {
w->SetNoDelay();
rc = w->NonBlocking(true);
rc = std::move(rc) << [&] { return w->SetNoDelay(); } << [&] { return w->NonBlocking(true); }
<< [&] { return w->SetKeepAlive(); };
}
if (!rc.OK()) {
return rc;

View File

@@ -10,21 +10,24 @@
#include <sstream> // for stringstream
#include <vector> // for vector
#include "../common/cuda_context.cuh" // for CUDAContext
#include "../common/device_helpers.cuh" // for DefaultStream
#include "../common/type.h" // for EraseType
#include "broadcast.h" // for Broadcast
#include "comm.cuh" // for NCCLComm
#include "comm.h" // for Comm
#include "nccl_stub.h" // for NcclStub
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
namespace {
Result GetUniqueId(Comm const& comm, std::shared_ptr<Coll> coll, ncclUniqueId* pid) {
Result GetUniqueId(Comm const& comm, std::shared_ptr<NcclStub> stub, std::shared_ptr<Coll> coll,
ncclUniqueId* pid) {
static const int kRootRank = 0;
ncclUniqueId id;
if (comm.Rank() == kRootRank) {
dh::safe_nccl(ncclGetUniqueId(&id));
auto rc = stub->GetUniqueId(&id);
CHECK(rc.OK()) << rc.Report();
}
auto rc = coll->Broadcast(
comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)}, kRootRank);
@@ -63,14 +66,15 @@ static std::string PrintUUID(xgboost::common::Span<std::uint64_t, kUuidLength> c
}
} // namespace
Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const {
return new NCCLComm{ctx, *this, pimpl};
Comm* RabitComm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const {
return new NCCLComm{ctx, *this, pimpl, StringView{this->nccl_path_}};
}
NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl)
NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl,
StringView nccl_path)
: Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(),
root.TaskID()},
stream_{dh::DefaultStream()} {
stream_{ctx->CUDACtx()->Stream()} {
this->world_ = root.World();
this->rank_ = root.Rank();
this->domain_ = root.Domain();
@@ -79,6 +83,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
}
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
stub_ = std::make_shared<NcclStub>(nccl_path);
std::vector<std::uint64_t> uuids(root.World() * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<std::uint64_t>{uuids.data(), uuids.size()};
@@ -104,19 +109,22 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
rc = GetUniqueId(root, pimpl, &nccl_unique_id_);
rc = std::move(rc) << [&] { return GetUniqueId(root, this->stub_, pimpl, &nccl_unique_id_); } <<
[&] {
return this->stub_->CommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank());
};
CHECK(rc.OK()) << rc.Report();
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank()));
for (std::int32_t r = 0; r < root.World(); ++r) {
this->channels_.emplace_back(
std::make_shared<NCCLChannel>(root, r, nccl_comm_, dh::DefaultStream()));
std::make_shared<NCCLChannel>(root, r, nccl_comm_, stub_, dh::DefaultStream()));
}
}
NCCLComm::~NCCLComm() {
if (nccl_comm_) {
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
auto rc = stub_->CommDestroy(nccl_comm_);
CHECK(rc.OK()) << rc.Report();
}
}
} // namespace xgboost::collective

View File

@@ -9,9 +9,13 @@
#include "../common/cuda_to_hip.h"
#include "rccl.h"
#endif // XGBOOST_USE_NCCL
#include <utility> // for move
#include "../common/device_helpers.cuh"
#include "coll.h"
#include "comm.h"
#include "nccl_stub.h" // for NcclStub
#include "xgboost/context.h"
namespace xgboost::collective {
@@ -24,15 +28,20 @@ inline Result GetCUDAResult(cudaError rc) {
return Fail(msg);
}
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
class NCCLComm : public Comm {
ncclComm_t nccl_comm_{nullptr};
std::shared_ptr<NcclStub> stub_;
ncclUniqueId nccl_unique_id_{};
dh::CUDAStreamView stream_;
std::string nccl_path_;
public:
[[nodiscard]] ncclComm_t Handle() const { return nccl_comm_; }
auto Stub() const { return stub_; }
explicit NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl);
explicit NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl,
StringView nccl_path);
[[nodiscard]] Result LogTracker(std::string) const override {
LOG(FATAL) << "Device comm is used for logging.";
return Fail("Undefined.");
@@ -49,22 +58,29 @@ class NCCLComm : public Comm {
class NCCLChannel : public Channel {
std::int32_t rank_{-1};
ncclComm_t nccl_comm_{};
std::shared_ptr<NcclStub> stub_;
dh::CUDAStreamView stream_;
public:
explicit NCCLChannel(Comm const& comm, std::int32_t rank, ncclComm_t nccl_comm,
dh::CUDAStreamView stream)
: rank_{rank}, nccl_comm_{nccl_comm}, Channel{comm, nullptr}, stream_{stream} {}
std::shared_ptr<NcclStub> stub, dh::CUDAStreamView stream)
: rank_{rank},
nccl_comm_{nccl_comm},
stub_{std::move(stub)},
Channel{comm, nullptr},
stream_{stream} {}
void SendAll(std::int8_t const* ptr, std::size_t n) override {
dh::safe_nccl(ncclSend(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
[[nodiscard]] Result SendAll(std::int8_t const* ptr, std::size_t n) override {
return stub_->Send(ptr, n, ncclInt8, rank_, nccl_comm_, stream_);
}
void RecvAll(std::int8_t* ptr, std::size_t n) override {
dh::safe_nccl(ncclRecv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
[[nodiscard]] Result RecvAll(std::int8_t* ptr, std::size_t n) override {
return stub_->Recv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_);
}
[[nodiscard]] Result Block() override {
auto rc = stream_.Sync(false);
return GetCUDAResult(rc);
}
};
#endif // defined(XGBOOST_USE_NCCL)
} // namespace xgboost::collective

View File

@@ -34,6 +34,8 @@ inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) {
return nrank;
}
inline StringView DefaultNcclName() { return "libnccl.so.2"; }
class Channel;
class Coll;
@@ -86,11 +88,21 @@ class Comm : public std::enable_shared_from_this<Comm> {
[[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
virtual Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const;
};
class RabitComm : public Comm {
/**
* @brief Base class for CPU-based communicator.
*/
class HostComm : public Comm {
public:
using Comm::Comm;
[[nodiscard]] virtual Comm* MakeCUDAVar(Context const* ctx,
std::shared_ptr<Coll> pimpl) const = 0;
};
class RabitComm : public HostComm {
std::string nccl_path_ = std::string{DefaultNcclName()};
[[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
std::string task_id);
[[nodiscard]] Result Shutdown();
@@ -100,13 +112,15 @@ class RabitComm : public Comm {
RabitComm() = default;
// ctor for testing where environment is known.
RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t retry, std::string task_id);
std::int32_t retry, std::string task_id, StringView nccl_path);
~RabitComm() noexcept(false) override;
[[nodiscard]] bool IsFederated() const override { return false; }
[[nodiscard]] Result LogTracker(std::string msg) const override;
[[nodiscard]] Result SignalError(Result const&) override;
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
};
/**
@@ -121,21 +135,25 @@ class Channel {
explicit Channel(Comm const& comm, std::shared_ptr<TCPSocket> sock)
: sock_{std::move(sock)}, comm_{comm} {}
virtual void SendAll(std::int8_t const* ptr, std::size_t n) {
[[nodiscard]] virtual Result SendAll(std::int8_t const* ptr, std::size_t n) {
Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast<std::int8_t*>(ptr), n, sock_.get(), 0};
CHECK(sock_.get());
comm_.Submit(std::move(op));
return Success();
}
void SendAll(common::Span<std::int8_t const> data) {
this->SendAll(data.data(), data.size_bytes());
[[nodiscard]] Result SendAll(common::Span<std::int8_t const> data) {
return this->SendAll(data.data(), data.size_bytes());
}
virtual void RecvAll(std::int8_t* ptr, std::size_t n) {
[[nodiscard]] virtual Result RecvAll(std::int8_t* ptr, std::size_t n) {
Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0};
CHECK(sock_.get());
comm_.Submit(std::move(op));
return Success();
}
[[nodiscard]] Result RecvAll(common::Span<std::int8_t> data) {
return this->RecvAll(data.data(), data.size_bytes());
}
void RecvAll(common::Span<std::int8_t> data) { this->RecvAll(data.data(), data.size_bytes()); }
[[nodiscard]] auto Socket() const { return sock_; }
[[nodiscard]] virtual Result Block() { return comm_.Block(); }

View File

@@ -0,0 +1,122 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "comm_group.h"
#include <algorithm> // for transform
#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <memory> // for shared_ptr, unique_ptr
#include <string> // for string
#include <vector> // for vector
#include "../common/json_utils.h" // for OptionalArg
#include "coll.h" // for Coll
#include "comm.h" // for Comm
#include "tracker.h" // for GetHostAddress
#include "xgboost/collective/result.h" // for Result
#include "xgboost/context.h" // for DeviceOrd
#include "xgboost/json.h" // for Json
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_coll.h"
#include "../../plugin/federated/federated_comm.h"
#endif
namespace xgboost::collective {
[[nodiscard]] std::shared_ptr<Coll> CommGroup::Backend(DeviceOrd device) const {
if (device.IsCUDA()) {
if (!gpu_coll_) {
gpu_coll_.reset(backend_->MakeCUDAVar());
}
return gpu_coll_;
}
return backend_;
}
[[nodiscard]] Comm const& CommGroup::Ctx(Context const* ctx, DeviceOrd device) const {
if (device.IsCUDA()) {
CHECK(ctx->IsCUDA());
if (!gpu_comm_ || gpu_comm_->World() != comm_->World()) {
gpu_comm_.reset(comm_->MakeCUDAVar(ctx, backend_));
}
return *gpu_comm_;
}
return *comm_;
}
CommGroup::CommGroup()
: comm_{std::shared_ptr<RabitComm>(new RabitComm{})}, // NOLINT
backend_{std::shared_ptr<Coll>(new Coll{})} {} // NOLINT
[[nodiscard]] CommGroup* CommGroup::Create(Json config) {
if (IsA<Null>(config)) {
return new CommGroup;
}
std::string type = OptionalArg<String>(config, "dmlc_communicator", std::string{"rabit"});
// Try both lower and upper case for compatibility
auto get_param = [&](std::string name, auto dft, auto t) {
std::string upper;
std::transform(name.cbegin(), name.cend(), std::back_inserter(upper),
[](char c) { return std::toupper(c); });
std::transform(name.cbegin(), name.cend(), name.begin(),
[](char c) { return std::tolower(c); });
auto const& obj = get<Object const>(config);
auto it = obj.find(upper);
if (it != obj.cend()) {
return OptionalArg<decltype(t)>(config, upper, dft);
} else {
return OptionalArg<decltype(t)>(config, name, dft);
}
};
// Common args
auto retry = get_param("dmlc_retry", static_cast<Integer::Int>(DefaultRetry()), Integer{});
auto timeout =
get_param("dmlc_timeout_sec", static_cast<Integer::Int>(DefaultTimeoutSec()), Integer{});
auto task_id = get_param("dmlc_task_id", std::string{}, String{});
if (type == "rabit") {
auto host = get_param("dmlc_tracker_uri", std::string{}, String{});
auto port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
auto nccl = get_param("dmlc_nccl_path", std::string{DefaultNcclName()}, String{});
auto ptr =
new CommGroup{std::shared_ptr<RabitComm>{new RabitComm{ // NOLINT
host, static_cast<std::int32_t>(port), std::chrono::seconds{timeout},
static_cast<std::int32_t>(retry), task_id, nccl}},
std::shared_ptr<Coll>(new Coll{})}; // NOLINT
return ptr;
} else if (type == "federated") {
#if defined(XGBOOST_USE_FEDERATED)
auto ptr = new CommGroup{
std::make_shared<FederatedComm>(retry, std::chrono::seconds{timeout}, task_id, config),
std::make_shared<FederatedColl>()};
return ptr;
#endif // defined(XGBOOST_USE_FEDERATED)
} else {
LOG(FATAL) << "Invalid communicator type";
}
return nullptr;
}
std::unique_ptr<collective::CommGroup>& GlobalCommGroup() {
static thread_local std::unique_ptr<collective::CommGroup> sptr;
if (!sptr) {
Json config{Null{}};
sptr.reset(CommGroup::Create(config));
}
return sptr;
}
void GlobalCommGroupInit(Json config) {
auto& sptr = GlobalCommGroup();
sptr.reset(CommGroup::Create(std::move(config)));
}
void GlobalCommGroupFinalize() {
auto& sptr = GlobalCommGroup();
sptr.reset();
}
} // namespace xgboost::collective

View File

@@ -0,0 +1,55 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <memory> // for shared_ptr, unique_ptr
#include <string> // for string
#include <utility> // for move
#include "coll.h" // for Comm
#include "comm.h" // for Coll
#include "xgboost/collective/result.h" // for Result
#include "xgboost/collective/socket.h" // for GetHostName
namespace xgboost::collective {
/**
* @brief Communicator group used for double dispatching between communicators and
* collective implementations.
*/
class CommGroup {
std::shared_ptr<HostComm> comm_;
mutable std::shared_ptr<Comm> gpu_comm_;
std::shared_ptr<Coll> backend_;
mutable std::shared_ptr<Coll> gpu_coll_; // lazy initialization
CommGroup(std::shared_ptr<Comm> comm, std::shared_ptr<Coll> coll)
: comm_{std::dynamic_pointer_cast<HostComm>(comm)}, backend_{std::move(coll)} {
CHECK(comm_);
}
public:
CommGroup();
[[nodiscard]] auto World() const { return comm_->World(); }
[[nodiscard]] auto Rank() const { return comm_->Rank(); }
[[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); }
[[nodiscard]] static CommGroup* Create(Json config);
[[nodiscard]] std::shared_ptr<Coll> Backend(DeviceOrd device) const;
[[nodiscard]] Comm const& Ctx(Context const* ctx, DeviceOrd device) const;
[[nodiscard]] Result SignalError(Result const& res) { return comm_->SignalError(res); }
[[nodiscard]] Result ProcessorName(std::string* out) const {
auto rc = GetHostName(out);
return rc;
}
};
std::unique_ptr<collective::CommGroup>& GlobalCommGroup();
void GlobalCommGroupInit(Json config);
void GlobalCommGroupFinalize();
} // namespace xgboost::collective

View File

@@ -3,6 +3,7 @@
*/
#include "communicator.h"
#include "comm.h"
#include "in_memory_communicator.h"
#include "noop_communicator.h"
#include "rabit_communicator.h"
@@ -14,8 +15,12 @@
namespace xgboost::collective {
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
thread_local CommunicatorType Communicator::type_{};
thread_local std::string Communicator::nccl_path_{};
void Communicator::Init(Json const& config) {
auto nccl = OptionalArg<String>(config, "dmlc_nccl_path", std::string{DefaultNcclName()});
nccl_path_ = nccl;
auto type = GetTypeFromEnv();
auto const arg = GetTypeFromConfig(config);
if (arg != CommunicatorType::kUnknown) {

View File

@@ -31,17 +31,17 @@ DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
switch (type_) {
case CommunicatorType::kRabit:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false, nccl_path_));
break;
case CommunicatorType::kFederated:
case CommunicatorType::kInMemory:
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
break;
case CommunicatorType::kInMemoryNccl:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true));
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true, nccl_path_));
break;
default:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false, nccl_path_));
}
#else
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));

View File

@@ -234,6 +234,7 @@ class Communicator {
static thread_local std::unique_ptr<Communicator> communicator_;
static thread_local CommunicatorType type_;
static thread_local std::string nccl_path_;
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
#endif

View File

@@ -10,21 +10,26 @@
#include "xgboost/logging.h" // for CHECK
namespace xgboost::collective {
Result Loop::EmptyQueue() {
Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
timer_.Start(__func__);
auto error = [this] {
this->stop_ = true;
auto error = [this] { timer_.Stop(__func__); };
if (stop_) {
timer_.Stop(__func__);
};
return Success();
}
while (!queue_.empty() && !stop_) {
std::queue<Op> qcopy;
auto& qcopy = *p_queue;
// clear the copied queue
while (!qcopy.empty()) {
rabit::utils::PollHelper poll;
std::size_t n_ops = qcopy.size();
// watch all ops
while (!queue_.empty()) {
auto op = queue_.front();
queue_.pop();
// Iterate through all the ops for poll
for (std::size_t i = 0; i < n_ops; ++i) {
auto op = qcopy.front();
qcopy.pop();
switch (op.code) {
case Op::kRead: {
@@ -40,6 +45,7 @@ Result Loop::EmptyQueue() {
return Fail("Invalid socket operation.");
}
}
qcopy.push(op);
}
@@ -51,10 +57,12 @@ Result Loop::EmptyQueue() {
error();
return rc;
}
// we wonldn't be here if the queue is empty.
CHECK(!qcopy.empty());
while (!qcopy.empty() && !stop_) {
// Iterate through all the ops for performing the operations
for (std::size_t i = 0; i < n_ops; ++i) {
auto op = qcopy.front();
qcopy.pop();
@@ -81,20 +89,21 @@ Result Loop::EmptyQueue() {
}
if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) {
stop_ = true;
auto rc = system::FailWithCode("Invalid socket output.");
error();
return rc;
}
op.off += n_bytes_done;
CHECK_LE(op.off, op.n);
if (op.off != op.n) {
// not yet finished, push back to queue for next round.
queue_.push(op);
qcopy.push(op);
}
}
}
timer_.Stop(__func__);
return Success();
}
@@ -107,22 +116,46 @@ void Loop::Process() {
if (stop_) {
break;
}
CHECK(!mu_.try_lock());
this->rc_ = this->EmptyQueue();
if (!rc_.OK()) {
stop_ = true;
auto unlock_notify = [&](bool is_blocking, bool stop) {
if (!is_blocking) {
std::lock_guard guard{mu_};
stop_ = stop;
} else {
stop_ = stop;
lock.unlock();
}
cv_.notify_one();
break;
};
// move the queue
std::queue<Op> qcopy;
bool is_blocking = false;
while (!queue_.empty()) {
auto op = queue_.front();
queue_.pop();
if (op.code == Op::kBlock) {
is_blocking = true;
} else {
qcopy.push(op);
}
}
// unblock the queue
if (!is_blocking) {
lock.unlock();
}
// clear the queue
auto rc = this->EmptyQueue(&qcopy);
// Handle error
if (!rc.OK()) {
unlock_notify(is_blocking, true);
std::lock_guard<std::mutex> guard{rc_lock_};
this->rc_ = std::move(rc);
return;
}
CHECK(queue_.empty());
CHECK(!mu_.try_lock());
cv_.notify_one();
}
if (rc_.OK()) {
CHECK(queue_.empty());
CHECK(qcopy.empty());
unlock_notify(is_blocking, false);
}
}
@@ -140,6 +173,24 @@ Result Loop::Stop() {
return Success();
}
[[nodiscard]] Result Loop::Block() {
  // Fail fast: if the worker thread already recorded an error, surface it
  // without submitting more work.
  {
    std::lock_guard guard{rc_lock_};
    if (!rc_.OK()) {
      return std::move(rc_);
    }
  }
  // Enqueue a blocking marker, then wait for the worker to drain the queue
  // (or for the loop to be stopped).
  this->Submit(Op{Op::kBlock});
  {
    std::unique_lock lock{mu_};
    cv_.wait(lock, [this] { return this->queue_.empty() || stop_; });
  }
  // Hand the caller whatever status the worker produced for this round.
  std::lock_guard guard{rc_lock_};
  return std::move(rc_);
}
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
timer_.Init(__func__);
worker_ = std::thread{[this] {

View File

@@ -20,13 +20,14 @@ namespace xgboost::collective {
class Loop {
public:
struct Op {
enum Code : std::int8_t { kRead = 0, kWrite = 1 } code;
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code;
std::int32_t rank{-1};
std::int8_t* ptr{nullptr};
std::size_t n{0};
TCPSocket* sock{nullptr};
std::size_t off{0};
explicit Op(Code c) : code{c} { CHECK(c == kBlock); }
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
Op(Op const&) = default;
@@ -41,12 +42,15 @@ class Loop {
std::mutex mu_;
std::queue<Op> queue_;
std::chrono::seconds timeout_;
Result rc_;
std::mutex rc_lock_; // lock for transferring error info.
bool stop_{false};
std::exception_ptr curr_exce_{nullptr};
common::Monitor timer_;
common::Monitor mutable timer_;
Result EmptyQueue();
Result EmptyQueue(std::queue<Op>* p_queue) const;
void Process();
public:
@@ -60,15 +64,7 @@ class Loop {
cv_.notify_one();
}
[[nodiscard]] Result Block() {
{
std::unique_lock lock{mu_};
cv_.notify_all();
}
std::unique_lock lock{mu_};
cv_.wait(lock, [this] { return this->queue_.empty() || stop_; });
return std::move(rc_);
}
[[nodiscard]] Result Block();
explicit Loop(std::chrono::seconds timeout);

View File

@@ -2,12 +2,14 @@
* Copyright 2023 XGBoost contributors
*/
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
#include "comm.cuh"
#include "nccl_device_communicator.cuh"
namespace xgboost {
namespace collective {
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync)
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync,
StringView nccl_path)
: device_ordinal_{device_ordinal},
needs_sync_{needs_sync},
world_size_{GetWorldSize()},
@@ -18,6 +20,7 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sy
if (world_size_ == 1) {
return;
}
stub_ = std::make_shared<NcclStub>(std::move(nccl_path));
std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
@@ -43,7 +46,8 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sy
nccl_unique_id_ = GetUniqueId();
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
auto rc = stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_);
CHECK(rc.OK()) << rc.Report();
}
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
@@ -51,7 +55,8 @@ NcclDeviceCommunicator::~NcclDeviceCommunicator() {
return;
}
if (nccl_comm_) {
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
auto rc = stub_->CommDestroy(nccl_comm_);
CHECK(rc.OK()) << rc.Report();
}
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
LOG(CONSOLE) << "======== NCCL Statistics========";
@@ -137,8 +142,9 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
auto *device_buffer = buffer.data().get();
// First gather data from all the workers.
dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
nccl_comm_, dh::DefaultStream()));
auto rc = stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
nccl_comm_, dh::DefaultStream());
CHECK(rc.OK()) << rc.Report();
if (needs_sync_) {
dh::DefaultStream().Sync();
}
@@ -170,9 +176,10 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
if (IsBitwiseOp(op)) {
BitwiseAllReduce(send_receive_buffer, count, data_type, op);
} else {
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
dh::DefaultStream()));
auto rc = stub_->Allreduce(send_receive_buffer, send_receive_buffer, count,
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
dh::DefaultStream());
CHECK(rc.OK()) << rc.Report();
}
allreduce_bytes_ += count * GetTypeSize(data_type);
allreduce_calls_ += 1;
@@ -185,8 +192,9 @@ void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_bu
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllGather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
dh::DefaultStream()));
auto rc = stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
dh::DefaultStream());
CHECK(rc.OK()) << rc.Report();
}
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
@@ -206,14 +214,18 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
receive_buffer->resize(total_bytes);
size_t offset = 0;
dh::safe_nccl(ncclGroupStart());
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
ncclChar, i, nccl_comm_, dh::DefaultStream()));
offset += as_bytes;
}
dh::safe_nccl(ncclGroupEnd());
auto rc = Success() << [&] { return stub_->GroupStart(); } << [&] {
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
auto rc = stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
ncclChar, i, nccl_comm_, dh::DefaultStream());
if (!rc.OK()) {
return rc;
}
offset += as_bytes;
}
return Success();
} << [&] { return stub_->GroupEnd(); };
}
void NcclDeviceCommunicator::Synchronize() {

View File

@@ -4,8 +4,10 @@
#pragma once
#include "../common/device_helpers.cuh"
#include "comm.cuh"
#include "communicator.h"
#include "device_communicator.cuh"
#include "nccl_stub.h"
namespace xgboost {
namespace collective {
@@ -25,7 +27,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
* needed. The in-memory communicator is used in tests with multiple threads, each thread
* representing a rank/worker, so the additional synchronization is needed to avoid deadlocks.
*/
explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync);
explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync, StringView nccl_path);
~NcclDeviceCommunicator() override;
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override;
@@ -74,7 +76,8 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
static const int kRootRank = 0;
ncclUniqueId id;
if (rank_ == kRootRank) {
dh::safe_nccl(ncclGetUniqueId(&id));
auto rc = stub_->GetUniqueId(&id);
CHECK(rc.OK()) << rc.Report();
}
Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
return id;
@@ -88,6 +91,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
int const world_size_;
int const rank_;
ncclComm_t nccl_comm_{};
std::shared_ptr<NcclStub> stub_;
ncclUniqueId nccl_unique_id_{};
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.

131
src/collective/nccl_stub.cc Normal file
View File

@@ -0,0 +1,131 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#if defined(XGBOOST_USE_NCCL) || (defined(XGBOOST_USE_RCCL) && 0)
#include "nccl_stub.h"
#include <cuda.h> // for CUDA_VERSION
#include <cuda_runtime_api.h> // for cudaPeekAtLastError
#include <dlfcn.h> // for dlclose, dlsym, dlopen
#include <nccl.h>
#include <thrust/system/cuda/error.h> // for cuda_category
#include <thrust/system_error.h> // for system_error
#include <cstdint> // for int32_t
#include <sstream> // for stringstream
#include <string> // for string
#include <utility> // for move
#include "xgboost/logging.h"
namespace xgboost::collective {
// Translate an NCCL status code into a Result, attaching extra diagnostic
// context for the failure modes where NCCL provides more detail.
Result NcclStub::GetNcclResult(ncclResult_t code) const {
  if (code == ncclSuccess) {
    return Success();
  }

  std::stringstream msg;
  msg << "NCCL failure: " << this->GetErrorString(code) << ".";
  switch (code) {
    case ncclUnhandledCudaError: {
      // nccl usually preserves the last error so we can get more details.
      auto cu_err = cudaPeekAtLastError();
      msg << " CUDA error: " << thrust::system_error(cu_err, thrust::cuda_category()).what()
          << "\n";
      break;
    }
    case ncclSystemError: {
      msg << " This might be caused by a network configuration issue. Please consider specifying "
             "the network interface for NCCL via environment variables listed in its reference: "
             "`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
      break;
    }
    default:
      break;
  }
  return Fail(msg.str());
}
/**
 * @brief Construct the stub, resolving every NCCL entry point it wraps.
 *
 * When XGBoost is built to dlopen NCCL, the shared object at @p path is opened and
 * each symbol is resolved with dlsym; any failure is fatal and the error message
 * includes installation hints. Otherwise the symbols are bound directly to the
 * NCCL library linked at build time.
 *
 * @param path Path to the NCCL shared object. Must be non-empty in the dlopen build;
 *             unused (beyond being stored) in the statically-bound build.
 */
NcclStub::NcclStub(StringView path) : path_{std::move(path)} {
#if defined(XGBOOST_USE_DLOPEN_NCCL) || defined(XGBOOST_USE_DLOPEN_RCCL)
  CHECK(!path_.empty()) << "Empty path for NCCL.";
  // Major CUDA version, used to suggest the matching nvidia-nccl-cuXX wheel.
  auto cu_major = (CUDA_VERSION) / 1000;
  std::stringstream ss;
  ss << R"m(
If XGBoost is installed from PyPI with pip, the error can be fixed by:
- Run `pip install nvidia-nccl-cu)m"
     << cu_major << "` (Or with any CUDA version that's compatible with " << cu_major << ").";
  ss << R"m(
Otherwise, please refer to:
https://xgboost.readthedocs.io/en/stable/tutorials/dask.html#troubleshooting
for more info, or open an issue on GitHub. Starting from XGBoost 2.1.0, the PyPI package
no longer bundles NCCL in the binary wheel.
)m";
  auto help = ss.str();
  std::string msg{"Failed to load NCCL from path: `" + path_ + "`. Error:\n "};
  // Resolve a single symbol; aborts with an actionable message on failure. The
  // first parameter only carries the pointer type for the cast.
  auto safe_load = [&](auto t, StringView name) {
    std::stringstream errs;
    auto ptr = reinterpret_cast<decltype(t)>(dlsym(handle_, name.c_str()));
    if (!ptr) {
      errs << "Failed to load NCCL symbol `" << name << "` from " << path_ << ". Error:\n "
           << dlerror() << help;
      LOG(FATAL) << errs.str();
    }
    return ptr;
  };

  handle_ = dlopen(path_.c_str(), RTLD_LAZY);
  if (!handle_) {
    LOG(FATAL) << msg << dlerror() << help;
  }

  allreduce_ = safe_load(allreduce_, "ncclAllReduce");
  broadcast_ = safe_load(broadcast_, "ncclBroadcast");
  allgather_ = safe_load(allgather_, "ncclAllGather");
  comm_init_rank_ = safe_load(comm_init_rank_, "ncclCommInitRank");
  comm_destroy_ = safe_load(comm_destroy_, "ncclCommDestroy");
  get_uniqueid_ = safe_load(get_uniqueid_, "ncclGetUniqueId");
  send_ = safe_load(send_, "ncclSend");
  recv_ = safe_load(recv_, "ncclRecv");
  group_start_ = safe_load(group_start_, "ncclGroupStart");
  group_end_ = safe_load(group_end_, "ncclGroupEnd");
  get_error_string_ = safe_load(get_error_string_, "ncclGetErrorString");
  get_version_ = safe_load(get_version_, "ncclGetVersion");

  // Log the runtime NCCL version so mismatches with the build-time headers are visible.
  std::int32_t v;
  CHECK_EQ(get_version_(&v), ncclSuccess);
  auto patch = v % 100;
  auto minor = (v / 100) % 100;
  auto major = v / 10000;
  LOG(INFO) << "Loaded shared NCCL " << major << "." << minor << "." << patch << ":`" << path_
            << "`" << std::endl;
#else
  // Statically bound build: point every stub entry at the linked NCCL symbols.
  allreduce_ = ncclAllReduce;
  broadcast_ = ncclBroadcast;
  allgather_ = ncclAllGather;
  comm_init_rank_ = ncclCommInitRank;
  comm_destroy_ = ncclCommDestroy;
  get_uniqueid_ = ncclGetUniqueId;
  send_ = ncclSend;
  recv_ = ncclRecv;
  group_start_ = ncclGroupStart;
  group_end_ = ncclGroupEnd;
  get_error_string_ = ncclGetErrorString;
  get_version_ = ncclGetVersion;
#endif
}
// Release the dynamically loaded NCCL library, if any was opened.
NcclStub::~NcclStub() {  // NOLINT
#if defined(XGBOOST_USE_DLOPEN_NCCL) || defined(XGBOOST_USE_DLOPEN_RCCL)
  if (handle_ != nullptr) {
    auto status = dlclose(handle_);
    if (status != 0) {
      // Destructors must not throw; report the failure and move on.
      LOG(WARNING) << "Failed to close NCCL handle:" << dlerror();
    }
  }
  handle_ = nullptr;
#endif  // defined(XGBOOST_USE_DLOPEN_NCCL)
}
} // namespace xgboost::collective
#endif // defined(XGBOOST_USE_NCCL)

View File

@@ -0,0 +1,86 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#if defined(XGBOOST_USE_NCCL) || (defined(XGBOOST_USE_RCCL) && 0)
#include <cuda_runtime_api.h>
#include <nccl.h>
#include <string> // for string
#include "xgboost/collective/result.h" // for Result
#include "xgboost/string_view.h" // for StringView
namespace xgboost::collective {
/**
* @brief A stub for NCCL to facilitate dynamic loading.
*/
class NcclStub {
#if defined(XGBOOST_USE_DLOPEN_NCCL) || defined(XGBOOST_USE_DLOPEN_RCCL)
  // Handle returned by dlopen; owned by this stub, released in the destructor.
  void* handle_{nullptr};
#endif  // defined(XGBOOST_USE_DLOPEN_NCCL)
  std::string path_;  // Path of the shared object the symbols come from.
  // Entry points, resolved either via dlsym or bound directly to the linked NCCL.
  decltype(ncclAllReduce)* allreduce_{nullptr};
  decltype(ncclBroadcast)* broadcast_{nullptr};
  decltype(ncclAllGather)* allgather_{nullptr};
  decltype(ncclCommInitRank)* comm_init_rank_{nullptr};
  decltype(ncclCommDestroy)* comm_destroy_{nullptr};
  decltype(ncclGetUniqueId)* get_uniqueid_{nullptr};
  decltype(ncclSend)* send_{nullptr};
  decltype(ncclRecv)* recv_{nullptr};
  decltype(ncclGroupStart)* group_start_{nullptr};
  decltype(ncclGroupEnd)* group_end_{nullptr};
  decltype(ncclGetErrorString)* get_error_string_{nullptr};
  decltype(ncclGetVersion)* get_version_{nullptr};

 public:
  /**
   * @brief Translate an NCCL status code into a Result, with extra context on failure.
   */
  Result GetNcclResult(ncclResult_t code) const;

 public:
  explicit NcclStub(StringView path);
  ~NcclStub();

  // Each wrapper below forwards to the resolved NCCL entry point and converts
  // the returned status into a Result via GetNcclResult.
  [[nodiscard]] Result Allreduce(const void* sendbuff, void* recvbuff, size_t count,
                                 ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
                                 cudaStream_t stream) const {
    return this->GetNcclResult(allreduce_(sendbuff, recvbuff, count, datatype, op, comm, stream));
  }
  [[nodiscard]] Result Broadcast(const void* sendbuff, void* recvbuff, size_t count,
                                 ncclDataType_t datatype, int root, ncclComm_t comm,
                                 cudaStream_t stream) const {
    return this->GetNcclResult(broadcast_(sendbuff, recvbuff, count, datatype, root, comm, stream));
  }
  [[nodiscard]] Result Allgather(const void* sendbuff, void* recvbuff, size_t sendcount,
                                 ncclDataType_t datatype, ncclComm_t comm,
                                 cudaStream_t stream) const {
    return this->GetNcclResult(allgather_(sendbuff, recvbuff, sendcount, datatype, comm, stream));
  }
  [[nodiscard]] Result CommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId,
                                    int rank) const {
    return this->GetNcclResult(this->comm_init_rank_(comm, nranks, commId, rank));
  }
  [[nodiscard]] Result CommDestroy(ncclComm_t comm) const {
    return this->GetNcclResult(comm_destroy_(comm));
  }
  [[nodiscard]] Result GetUniqueId(ncclUniqueId* uniqueId) const {
    return this->GetNcclResult(get_uniqueid_(uniqueId));
  }
  // NOTE(review): marked const for consistency with Recv and the other wrappers;
  // the call mutates no stub state.
  [[nodiscard]] Result Send(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer,
                            ncclComm_t comm, cudaStream_t stream) const {
    return this->GetNcclResult(send_(sendbuff, count, datatype, peer, comm, stream));
  }
  [[nodiscard]] Result Recv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer,
                            ncclComm_t comm, cudaStream_t stream) const {
    return this->GetNcclResult(recv_(recvbuff, count, datatype, peer, comm, stream));
  }
  [[nodiscard]] Result GroupStart() const { return this->GetNcclResult(group_start_()); }
  [[nodiscard]] Result GroupEnd() const { return this->GetNcclResult(group_end_()); }
  [[nodiscard]] const char* GetErrorString(ncclResult_t result) const {
    return get_error_string_(result);
  }
};
} // namespace xgboost::collective
#endif // defined(XGBOOST_USE_NCCL)

View File

@@ -58,36 +58,35 @@ Result Tracker::WaitUntilReady() const {
RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr)
: sock_{std::move(sock)} {
auto host = addr.Addr();
std::int32_t rank{0};
rc_ = Success()
<< [&] { return proto::Magic{}.Verify(&sock_); }
<< [&] { return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_); };
if (!rc_.OK()) {
return;
}
std::string cmd;
sock_.Recv(&cmd);
auto jcmd = Json::Load(StringView{cmd});
cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
Json jcmd;
std::int32_t port{0};
if (cmd_ == proto::CMD::kStart) {
proto::Start start;
rc_ = start.TrackerHandle(jcmd, &world_, world, &port, &sock_, &eport_);
} else if (cmd_ == proto::CMD::kPrint) {
proto::Print print;
rc_ = print.TrackerHandle(jcmd, &msg_);
} else if (cmd_ == proto::CMD::kError) {
proto::ErrorCMD error;
rc_ = error.TrackerHandle(jcmd, &msg_, &code_);
}
if (!rc_.OK()) {
return;
}
info_ = proto::PeerInfo{host, port, rank};
rc_ = Success() << [&] { return proto::Magic{}.Verify(&sock_); } << [&] {
return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_);
} << [&] {
std::string cmd;
sock_.Recv(&cmd);
jcmd = Json::Load(StringView{cmd});
cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
return Success();
} << [&] {
if (cmd_ == proto::CMD::kStart) {
proto::Start start;
return start.TrackerHandle(jcmd, &world_, world, &port, &sock_, &eport_);
} else if (cmd_ == proto::CMD::kPrint) {
proto::Print print;
return print.TrackerHandle(jcmd, &msg_);
} else if (cmd_ == proto::CMD::kError) {
proto::ErrorCMD error;
return error.TrackerHandle(jcmd, &msg_, &code_);
}
return Success();
} << [&] {
auto host = addr.Addr();
info_ = proto::PeerInfo{host, port, rank};
return Success();
};
}
RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
@@ -137,15 +136,18 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
std::int32_t n_shutdown{0};
bool during_restart{false};
bool running{false};
std::vector<WorkerProxy> pending;
explicit State(std::int32_t world) : n_workers{world} {}
State(State const& that) = delete;
State& operator=(State&& that) = delete;
// modifiers
void Start(WorkerProxy&& worker) {
CHECK_LT(pending.size(), n_workers);
CHECK_LE(n_shutdown, n_workers);
CHECK(!running);
pending.emplace_back(std::forward<WorkerProxy>(worker));
@@ -155,6 +157,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
CHECK_GE(n_shutdown, 0);
CHECK_LT(n_shutdown, n_workers);
running = false;
++n_shutdown;
CHECK_LE(n_shutdown, n_workers);
@@ -163,21 +166,26 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
CHECK_LE(pending.size(), n_workers);
CHECK_LE(n_shutdown, n_workers);
running = false;
during_restart = true;
}
[[nodiscard]] bool Ready() const {
CHECK_LE(pending.size(), n_workers);
return static_cast<std::int32_t>(pending.size()) == n_workers;
}
void Bootstrap() {
CHECK_EQ(pending.size(), n_workers);
CHECK_LE(n_shutdown, n_workers);
running = true;
// A reset.
n_shutdown = 0;
during_restart = false;
pending.clear();
}
// observers
[[nodiscard]] bool Ready() const {
CHECK_LE(pending.size(), n_workers);
return static_cast<std::int32_t>(pending.size()) == n_workers;
}
[[nodiscard]] bool ShouldContinue() const {
CHECK_LE(pending.size(), n_workers);
CHECK_LE(n_shutdown, n_workers);
@@ -187,7 +195,31 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
}
};
return std::async(std::launch::async, [this] {
auto handle_error = [&](WorkerProxy const& worker) {
auto msg = worker.Msg();
auto code = worker.Code();
LOG(WARNING) << "Recieved error from [" << worker.Host() << ":" << worker.Rank() << "]: " << msg
<< " code:" << code;
auto host = worker.Host();
// We signal all workers for the error, if they haven't aborted already.
for (auto& w : worker_error_handles_) {
if (w.first == host) {
continue;
}
TCPSocket out;
// Connecting to the error port as a signal for exit.
//
// retry is set to 1, just let the worker timeout or error. Otherwise the
// tracker and the worker might be waiting for each other.
auto rc = Connect(w.first, w.second, 1, timeout_, &out);
if (!rc.OK()) {
return Fail("Failed to inform workers to stop.");
}
}
return Success();
};
return std::async(std::launch::async, [this, handle_error] {
State state{this->n_workers_};
while (state.ShouldContinue()) {
@@ -205,6 +237,16 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
}
switch (worker.Command()) {
case proto::CMD::kStart: {
if (state.running) {
// Something went wrong with one of the workers. It got disconnected without
// notice.
state.Error();
rc = handle_error(worker);
if (!rc.OK()) {
return Fail("Failed to handle abort.", std::move(rc));
}
}
state.Start(std::move(worker));
if (state.Ready()) {
rc = this->Bootstrap(&state.pending);
@@ -216,36 +258,20 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
continue;
}
case proto::CMD::kShutdown: {
if (state.during_restart) {
// The worker can still send shutdown after call to `std::exit`.
continue;
}
state.Shutdown();
continue;
}
case proto::CMD::kError: {
if (state.during_restart) {
// Ignore further errors.
continue;
}
state.Error();
auto msg = worker.Msg();
auto code = worker.Code();
LOG(WARNING) << "Recieved error from [" << worker.Host() << ":" << worker.Rank()
<< "]: " << msg << " code:" << code;
auto host = worker.Host();
// We signal all workers for the error, if they haven't aborted already.
for (auto& w : worker_error_handles_) {
if (w.first == host) {
continue;
}
TCPSocket out;
// retry is set to 1, just let the worker timeout or error. Otherwise the
// tracker and the worker might be waiting for each other.
auto rc = Connect(w.first, w.second, 1, timeout_, &out);
// send signal to stop the worker.
proto::ShutdownCMD shutdown;
rc = shutdown.Send(&out);
if (!rc.OK()) {
return Fail("Failed to inform workers to stop.");
}
}
rc = handle_error(worker);
continue;
}
case proto::CMD::kPrint: {

View File

@@ -114,6 +114,9 @@ class RabitTracker : public Tracker {
// record for how to reach out to workers if error happens.
std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
// listening socket for incoming workers.
//
// At the moment, the listener calls accept without first polling. We can add an
// additional unix domain socket to allow cancelling the accept.
TCPSocket listener_;
Result Bootstrap(std::vector<WorkerProxy>* p_workers);