Cleanup code for distributed training. (#9805)

* Cleanup code for distributed training.

- Merge `GetNcclResult` into the NCCL stub.
- Split the utilities out of the main Dask module.
- Let `Channel` return `Result` to accommodate the NCCL channel (see the sketch below for the error-handling pattern).
- Remove the old `use_label_encoder` parameter.
Jiaming Yuan 2023-11-25 09:10:56 +08:00 committed by GitHub
parent e9260de3f3
commit 8fe1a2213c
19 changed files with 221 additions and 192 deletions
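The recurring pattern in the C++ changes below is that NCCL and socket calls now return `Result` and are composed with the existing `Success() << [&] {...}` chaining, which short-circuits at the first failure. The following is a minimal, self-contained sketch of that pattern. The `Result` type here is a simplified stand-in written only for illustration, not XGBoost's actual `xgboost/collective/result.h` implementation, and the lambda bodies merely mimic calls such as `SendAll`, `RecvAll`, and `Block`.

// Minimal stand-in illustrating the Result chaining used throughout this
// commit. This is NOT the real xgboost::collective::Result; it only mimics
// the parts of the interface the diffs rely on (OK(), Report(), Success(),
// Fail(), and operator<< over callables).
#include <iostream>
#include <string>
#include <utility>

class Result {
  bool ok_{true};
  std::string msg_;

 public:
  Result() = default;
  explicit Result(std::string msg) : ok_{false}, msg_{std::move(msg)} {}
  [[nodiscard]] bool OK() const { return ok_; }
  [[nodiscard]] std::string Report() const { return ok_ ? "OK" : msg_; }
};

inline Result Success() { return Result{}; }
inline Result Fail(std::string msg) { return Result{std::move(msg)}; }

// Run the next step only if the accumulated result is still OK.
template <typename Fn>
Result operator<<(Result&& prev, Fn&& fn) {
  if (!prev.OK()) {
    return std::move(prev);
  }
  return fn();
}

int main() {
  // Shape of the refactored call sites (e.g. RingAllgather): send, receive,
  // then block, stopping at the first step that fails.
  auto rc = Success() << [&] { return Success(); }             // next_ch->SendAll(...)
                      << [&] { return Fail("recv failed."); }  // prev_ch->RecvAll(...)
                      << [&] { return Success(); };            // prev_ch->Block(), skipped
  std::cout << rc.Report() << "\n";  // prints "recv failed."
  return rc.OK() ? 0 : 1;
}

With the old `void SendAll`/`RecvAll` signatures, a failed call had to be `CHECK`-ed inside the channel; returning `Result` lets every step participate in this chain, which is also why the stub now wraps each `ncclResult_t` via `NcclStub::GetNcclResult` instead of the removed free-standing `GetNCCLResult` helper.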


@ -94,6 +94,8 @@ from xgboost.sklearn import (
from xgboost.tracker import RabitTracker, get_host_ip
from xgboost.training import train as worker_train
from .utils import get_n_threads
if TYPE_CHECKING:
import dask
import distributed
@ -908,6 +910,34 @@ async def _check_workers_are_alive(
raise RuntimeError(f"Missing required workers: {missing_workers}")
def _get_dmatrices(
train_ref: dict,
train_id: int,
*refs: dict,
evals_id: Sequence[int],
evals_name: Sequence[str],
n_threads: int,
) -> Tuple[DMatrix, List[Tuple[DMatrix, str]]]:
Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
evals: List[Tuple[DMatrix, str]] = []
for i, ref in enumerate(refs):
if evals_id[i] == train_id:
evals.append((Xy, evals_name[i]))
continue
if ref.get("ref", None) is not None:
if ref["ref"] != train_id:
raise ValueError(
"The training DMatrix should be used as a reference to evaluation"
" `QuantileDMatrix`."
)
del ref["ref"]
eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads, ref=Xy)
else:
eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads)
evals.append((eval_Xy, evals_name[i]))
return Xy, evals
async def _train_async(
client: "distributed.Client",
global_config: Dict[str, Any],
@ -940,41 +970,20 @@ async def _train_async(
) -> Optional[TrainReturnT]:
worker = distributed.get_worker()
local_param = parameters.copy()
n_threads = 0
# dask worker nthreads, "state" is available in 2022.6.1
dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
for p in ["nthread", "n_jobs"]:
if (
local_param.get(p, None) is not None
and local_param.get(p, dwnt) != dwnt
):
LOGGER.info("Overriding `nthreads` defined in dask worker.")
n_threads = local_param[p]
break
if n_threads == 0 or n_threads is None:
n_threads = dwnt
n_threads = get_n_threads(local_param, worker)
local_param.update({"nthread": n_threads, "n_jobs": n_threads})
local_history: TrainingCallback.EvalsLog = {}
with CommunicatorContext(**rabit_args), config.config_context(**global_config):
Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
evals: List[Tuple[DMatrix, str]] = []
for i, ref in enumerate(refs):
if evals_id[i] == train_id:
evals.append((Xy, evals_name[i]))
continue
if ref.get("ref", None) is not None:
if ref["ref"] != train_id:
raise ValueError(
"The training DMatrix should be used as a reference"
" to evaluation `QuantileDMatrix`."
Xy, evals = _get_dmatrices(
train_ref,
train_id,
*refs,
evals_id=evals_id,
evals_name=evals_name,
n_threads=n_threads,
)
del ref["ref"]
eval_Xy = _dmatrix_from_list_of_parts(
**ref, nthread=n_threads, ref=Xy
)
else:
eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads)
evals.append((eval_Xy, evals_name[i]))
booster = worker_train(
params=local_param,


@ -0,0 +1,24 @@
"""Utilities for the XGBoost Dask interface."""
import logging
from typing import TYPE_CHECKING, Any, Dict
LOGGER = logging.getLogger("[xgboost.dask]")
if TYPE_CHECKING:
import distributed
def get_n_threads(local_param: Dict[str, Any], worker: "distributed.Worker") -> int:
"""Get the number of threads from a worker and the user-supplied parameters."""
# dask worker nthreads, "state" is available in 2022.6.1
dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
n_threads = None
for p in ["nthread", "n_jobs"]:
if local_param.get(p, None) is not None and local_param.get(p, dwnt) != dwnt:
LOGGER.info("Overriding `nthreads` defined in dask worker.")
n_threads = local_param[p]
break
if n_threads == 0 or n_threads is None:
n_threads = dwnt
return n_threads


@ -808,7 +808,6 @@ class XGBModel(XGBModelBase):
"kwargs",
"missing",
"n_estimators",
"use_label_encoder",
"enable_categorical",
"early_stopping_rounds",
"callbacks",


@ -138,7 +138,6 @@ _inverse_pyspark_param_alias_map = {v: k for k, v in _pyspark_param_alias_map.it
_unsupported_xgb_params = [
"gpu_id", # we have "device" pyspark param instead.
"enable_categorical", # Use feature_types param to specify categorical feature instead
"use_label_encoder",
"n_jobs", # Do not allow user to set it, will use `spark.task.cpus` value instead.
"nthread", # Ditto
]


@ -1,5 +1,5 @@
/**
* Copyright 2019-2023 by XGBoost Contributors
* Copyright 2019-2023, XGBoost Contributors
*/
#include <thrust/transform.h> // for transform
@ -15,6 +15,9 @@
#include "xgboost/data.h"
#include "xgboost/json.h"
#include "xgboost/learner.h"
#if defined(XGBOOST_USE_NCCL)
#include <nccl.h>
#endif
namespace xgboost {
void XGBBuildInfoDevice(Json *p_info) {


@ -26,18 +26,19 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
}
for (std::int32_t r = 0; r < world; ++r) {
auto rc = Success() << [&] {
auto send_rank = (rank + world - r + worker_off) % world;
auto send_off = send_rank * segment_size;
send_off = std::min(send_off, data.size_bytes());
auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off));
next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
return next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
} << [&] {
auto recv_rank = (rank + world - r - 1 + worker_off) % world;
auto recv_off = recv_rank * segment_size;
recv_off = std::min(recv_off, data.size_bytes());
auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off));
prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
auto rc = prev_ch->Block();
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
} << [&] { return prev_ch->Block(); };
if (!rc.OK()) {
return rc;
}
@ -78,19 +79,19 @@ namespace detail {
auto next_ch = comm.Chan(next);
for (std::int32_t r = 0; r < world; ++r) {
auto rc = Success() << [&] {
auto send_rank = (rank + world - r) % world;
auto send_off = offset[send_rank];
auto send_size = sizes[send_rank];
auto send_seg = erased_result.subspan(send_off, send_size);
next_ch->SendAll(send_seg);
return next_ch->SendAll(send_seg);
} << [&] {
auto recv_rank = (rank + world - r - 1) % world;
auto recv_off = offset[recv_rank];
auto recv_size = sizes[recv_rank];
auto recv_seg = erased_result.subspan(recv_off, recv_size);
prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
auto rc = prev_ch->Block();
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
} << [&] { return prev_ch->Block(); };
if (!rc.OK()) {
return rc;
}


@ -37,7 +37,10 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg);
auto send_seg = data.subspan(send_off, seg_nbytes);
next_ch->SendAll(send_seg);
auto rc = next_ch->SendAll(send_seg);
if (!rc.OK()) {
return rc;
}
// receive from ring prev
auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg;
@ -47,8 +50,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
auto recv_seg = data.subspan(recv_off, seg_nbytes);
auto seg = s_buf.subspan(0, recv_seg.size());
prev_ch->RecvAll(seg);
auto rc = comm.Block();
rc = std::move(rc) << [&] { return prev_ch->RecvAll(seg); } << [&] { return comm.Block(); };
if (!rc.OK()) {
return rc;
}


@ -62,8 +62,8 @@ Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t
if (shifted_rank != 0) { // not root
auto parent = ShiftRight(ShiftedParentRank(shifted_rank, depth), world, root);
comm.Chan(parent)->RecvAll(data);
auto rc = comm.Chan(parent)->Block();
auto rc = Success() << [&] { return comm.Chan(parent)->RecvAll(data); }
<< [&] { return comm.Chan(parent)->Block(); };
if (!rc.OK()) {
return Fail("broadcast failed.", std::move(rc));
}
@ -75,7 +75,10 @@ Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t
auto sft_peer = shifted_rank + (1 << i);
auto peer = ShiftRight(sft_peer, world, root);
CHECK_NE(peer, root);
comm.Chan(peer)->SendAll(data);
auto rc = comm.Chan(peer)->SendAll(data);
if (!rc.OK()) {
return rc;
}
}
}


@ -79,8 +79,8 @@ void RunBitwiseAllreduce(dh::CUDAStreamView stream, common::Span<std::int8_t> ou
// First gather data from all the workers.
CHECK(handle);
auto rc = GetNCCLResult(stub, stub->Allgather(data.data(), device_buffer, data.size(), ncclInt8,
handle, pcomm->Stream()));
auto rc =
stub->Allgather(data.data(), device_buffer, data.size(), ncclInt8, handle, pcomm->Stream());
if (!rc.OK()) {
return rc;
}
@ -140,9 +140,8 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
return DispatchDType(type, [=](auto t) {
using T = decltype(t);
auto rdata = common::RestoreType<T>(data);
auto rc = stub->Allreduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
return stub->Allreduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
GetNCCLRedOp(op), nccl->Handle(), nccl->Stream());
return GetNCCLResult(stub, rc);
});
}
} << [&] { return nccl->Block(); };
@ -158,8 +157,8 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
auto stub = nccl->Stub();
return Success() << [&] {
return GetNCCLResult(stub, stub->Broadcast(data.data(), data.data(), data.size_bytes(),
ncclInt8, root, nccl->Handle(), nccl->Stream()));
return stub->Broadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root,
nccl->Handle(), nccl->Stream());
} << [&] { return nccl->Block(); };
}
@ -174,8 +173,8 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
auto send = data.subspan(comm.Rank() * size, size);
return Success() << [&] {
return GetNCCLResult(stub, stub->Allgather(send.data(), data.data(), size, ncclInt8,
nccl->Handle(), nccl->Stream()));
return stub->Allgather(send.data(), data.data(), size, ncclInt8, nccl->Handle(),
nccl->Stream());
} << [&] { return nccl->Block(); };
}
@ -188,19 +187,19 @@ namespace cuda_impl {
Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes, common::Span<std::int8_t> recv) {
auto stub = comm->Stub();
return Success() << [&stub] { return GetNCCLResult(stub, stub->GroupStart()); } << [&] {
return Success() << [&stub] { return stub->GroupStart(); } << [&] {
std::size_t offset = 0;
for (std::int32_t r = 0; r < comm->World(); ++r) {
auto as_bytes = sizes[r];
auto rc = stub->Broadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
ncclInt8, r, comm->Handle(), dh::DefaultStream());
if (rc != ncclSuccess) {
return GetNCCLResult(stub, rc);
if (!rc.OK()) {
return rc;
}
offset += as_bytes;
}
return Success();
} << [&] { return GetNCCLResult(stub, stub->GroupEnd()); };
} << [&] { return stub->GroupEnd(); };
}
} // namespace cuda_impl
@ -217,7 +216,7 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const>
switch (algo) {
case AllgatherVAlgo::kRing: {
return Success() << [&] { return GetNCCLResult(stub, stub->GroupStart()); } << [&] {
return Success() << [&] { return stub->GroupStart(); } << [&] {
// get worker offset
detail::AllgatherVOffset(sizes, recv_segments);
// copy data
@ -228,7 +227,7 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const>
}
return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
} << [&] {
return GetNCCLResult(stub, stub->GroupEnd());
return stub->GroupEnd();
} << [&] { return nccl->Block(); };
}
case AllgatherVAlgo::kBcast: {


@ -26,7 +26,7 @@ Result GetUniqueId(Comm const& comm, std::shared_ptr<NcclStub> stub, std::shared
static const int kRootRank = 0;
ncclUniqueId id;
if (comm.Rank() == kRootRank) {
auto rc = GetNCCLResult(stub, stub->GetUniqueId(&id));
auto rc = stub->GetUniqueId(&id);
CHECK(rc.OK()) << rc.Report();
}
auto rc = coll->Broadcast(
@ -99,11 +99,9 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
rc = std::move(rc) << [&] {
return GetUniqueId(root, this->stub_, pimpl, &nccl_unique_id_);
} << [&] {
return GetNCCLResult(this->stub_, this->stub_->CommInitRank(&nccl_comm_, root.World(),
nccl_unique_id_, root.Rank()));
rc = std::move(rc) << [&] { return GetUniqueId(root, this->stub_, pimpl, &nccl_unique_id_); } <<
[&] {
return this->stub_->CommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank());
};
CHECK(rc.OK()) << rc.Report();
@ -115,7 +113,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
NCCLComm::~NCCLComm() {
if (nccl_comm_) {
auto rc = GetNCCLResult(stub_, stub_->CommDestroy(nccl_comm_));
auto rc = stub_->CommDestroy(nccl_comm_);
CHECK(rc.OK()) << rc.Report();
}
}


@ -52,25 +52,6 @@ class NCCLComm : public Comm {
}
};
inline Result GetNCCLResult(std::shared_ptr<NcclStub> stub, ncclResult_t code) {
if (code == ncclSuccess) {
return Success();
}
std::stringstream ss;
ss << "NCCL failure: " << stub->GetErrorString(code) << ".";
if (code == ncclUnhandledCudaError) {
// nccl usually preserves the last error so we can get more details.
auto err = cudaPeekAtLastError();
ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n";
} else if (code == ncclSystemError) {
ss << " This might be caused by a network configuration issue. Please consider specifying "
"the network interface for NCCL via environment variables listed in its reference: "
"`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
}
return Fail(ss.str());
}
class NCCLChannel : public Channel {
std::int32_t rank_{-1};
ncclComm_t nccl_comm_{};
@ -86,13 +67,11 @@ class NCCLChannel : public Channel {
Channel{comm, nullptr},
stream_{stream} {}
void SendAll(std::int8_t const* ptr, std::size_t n) override {
auto rc = GetNCCLResult(stub_, stub_->Send(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
CHECK(rc.OK()) << rc.Report();
[[nodiscard]] Result SendAll(std::int8_t const* ptr, std::size_t n) override {
return stub_->Send(ptr, n, ncclInt8, rank_, nccl_comm_, stream_);
}
void RecvAll(std::int8_t* ptr, std::size_t n) override {
auto rc = GetNCCLResult(stub_, stub_->Recv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
CHECK(rc.OK()) << rc.Report();
[[nodiscard]] Result RecvAll(std::int8_t* ptr, std::size_t n) override {
return stub_->Recv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_);
}
[[nodiscard]] Result Block() override {
auto rc = stream_.Sync(false);


@ -135,21 +135,25 @@ class Channel {
explicit Channel(Comm const& comm, std::shared_ptr<TCPSocket> sock)
: sock_{std::move(sock)}, comm_{comm} {}
virtual void SendAll(std::int8_t const* ptr, std::size_t n) {
[[nodiscard]] virtual Result SendAll(std::int8_t const* ptr, std::size_t n) {
Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast<std::int8_t*>(ptr), n, sock_.get(), 0};
CHECK(sock_.get());
comm_.Submit(std::move(op));
return Success();
}
void SendAll(common::Span<std::int8_t const> data) {
this->SendAll(data.data(), data.size_bytes());
[[nodiscard]] Result SendAll(common::Span<std::int8_t const> data) {
return this->SendAll(data.data(), data.size_bytes());
}
virtual void RecvAll(std::int8_t* ptr, std::size_t n) {
[[nodiscard]] virtual Result RecvAll(std::int8_t* ptr, std::size_t n) {
Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0};
CHECK(sock_.get());
comm_.Submit(std::move(op));
return Success();
}
[[nodiscard]] Result RecvAll(common::Span<std::int8_t> data) {
return this->RecvAll(data.data(), data.size_bytes());
}
void RecvAll(common::Span<std::int8_t> data) { this->RecvAll(data.data(), data.size_bytes()); }
[[nodiscard]] auto Socket() const { return sock_; }
[[nodiscard]] virtual Result Block() { return comm_.Block(); }


@ -46,8 +46,7 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sy
nccl_unique_id_ = GetUniqueId();
dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto rc =
GetNCCLResult(stub_, stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
auto rc = stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_);
CHECK(rc.OK()) << rc.Report();
}
@ -56,7 +55,7 @@ NcclDeviceCommunicator::~NcclDeviceCommunicator() {
return;
}
if (nccl_comm_) {
auto rc = GetNCCLResult(stub_, stub_->CommDestroy(nccl_comm_));
auto rc = stub_->CommDestroy(nccl_comm_);
CHECK(rc.OK()) << rc.Report();
}
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
@ -143,9 +142,8 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
auto *device_buffer = buffer.data().get();
// First gather data from all the workers.
auto rc = GetNCCLResult(
stub_, stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
nccl_comm_, dh::DefaultStream()));
auto rc = stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
nccl_comm_, dh::DefaultStream());
CHECK(rc.OK()) << rc.Report();
if (needs_sync_) {
dh::DefaultStream().Sync();
@ -178,9 +176,9 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
if (IsBitwiseOp(op)) {
BitwiseAllReduce(send_receive_buffer, count, data_type, op);
} else {
auto rc = GetNCCLResult(stub_, stub_->Allreduce(send_receive_buffer, send_receive_buffer, count,
GetNcclDataType(data_type), GetNcclRedOp(op),
nccl_comm_, dh::DefaultStream()));
auto rc = stub_->Allreduce(send_receive_buffer, send_receive_buffer, count,
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
dh::DefaultStream());
CHECK(rc.OK()) << rc.Report();
}
allreduce_bytes_ += count * GetTypeSize(data_type);
@ -194,8 +192,8 @@ void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_bu
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto rc = GetNCCLResult(stub_, stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8,
nccl_comm_, dh::DefaultStream()));
auto rc = stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
dh::DefaultStream());
CHECK(rc.OK()) << rc.Report();
}
@ -216,19 +214,18 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
receive_buffer->resize(total_bytes);
size_t offset = 0;
auto rc = Success() << [&] { return GetNCCLResult(stub_, stub_->GroupStart()); } << [&] {
auto rc = Success() << [&] { return stub_->GroupStart(); } << [&] {
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
auto rc = GetNCCLResult(
stub_, stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
ncclChar, i, nccl_comm_, dh::DefaultStream()));
auto rc = stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
ncclChar, i, nccl_comm_, dh::DefaultStream());
if (!rc.OK()) {
return rc;
}
offset += as_bytes;
}
return Success();
} << [&] { return GetNCCLResult(stub_, stub_->GroupEnd()); };
} << [&] { return stub_->GroupEnd(); };
}
void NcclDeviceCommunicator::Synchronize() {


@ -66,7 +66,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
static const int kRootRank = 0;
ncclUniqueId id;
if (rank_ == kRootRank) {
auto rc = GetNCCLResult(stub_, stub_->GetUniqueId(&id));
auto rc = stub_->GetUniqueId(&id);
CHECK(rc.OK()) << rc.Report();
}
Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));


@ -5,8 +5,11 @@
#include "nccl_stub.h"
#include <cuda.h> // for CUDA_VERSION
#include <cuda_runtime_api.h> // for cudaPeekAtLastError
#include <dlfcn.h> // for dlclose, dlsym, dlopen
#include <nccl.h>
#include <thrust/system/cuda/error.h> // for cuda_category
#include <thrust/system_error.h> // for system_error
#include <cstdint> // for int32_t
#include <sstream> // for stringstream
@ -16,6 +19,25 @@
#include "xgboost/logging.h"
namespace xgboost::collective {
Result NcclStub::GetNcclResult(ncclResult_t code) const {
if (code == ncclSuccess) {
return Success();
}
std::stringstream ss;
ss << "NCCL failure: " << this->GetErrorString(code) << ".";
if (code == ncclUnhandledCudaError) {
// nccl usually preserves the last error so we can get more details.
auto err = cudaPeekAtLastError();
ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n";
} else if (code == ncclSystemError) {
ss << " This might be caused by a network configuration issue. Please consider specifying "
"the network interface for NCCL via environment variables listed in its reference: "
"`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
}
return Fail(ss.str());
}
NcclStub::NcclStub(StringView path) : path_{std::move(path)} {
#if defined(XGBOOST_USE_DLOPEN_NCCL)
CHECK(!path_.empty()) << "Empty path for NCCL.";


@ -8,9 +8,13 @@
#include <string> // for string
#include "xgboost/collective/result.h" // for Result
#include "xgboost/string_view.h" // for StringView
namespace xgboost::collective {
/**
* @brief A stub for NCCL to facilitate dynamic loading.
*/
class NcclStub {
#if defined(XGBOOST_USE_DLOPEN_NCCL)
void* handle_{nullptr};
@ -30,61 +34,48 @@ class NcclStub {
decltype(ncclGetErrorString)* get_error_string_{nullptr};
decltype(ncclGetVersion)* get_version_{nullptr};
public:
Result GetNcclResult(ncclResult_t code) const;
public:
explicit NcclStub(StringView path);
~NcclStub();
[[nodiscard]] ncclResult_t Allreduce(const void* sendbuff, void* recvbuff, size_t count,
[[nodiscard]] Result Allreduce(const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
cudaStream_t stream) const {
CHECK(allreduce_);
return this->allreduce_(sendbuff, recvbuff, count, datatype, op, comm, stream);
return this->GetNcclResult(allreduce_(sendbuff, recvbuff, count, datatype, op, comm, stream));
}
[[nodiscard]] ncclResult_t Broadcast(const void* sendbuff, void* recvbuff, size_t count,
[[nodiscard]] Result Broadcast(const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, int root, ncclComm_t comm,
cudaStream_t stream) const {
CHECK(broadcast_);
return this->broadcast_(sendbuff, recvbuff, count, datatype, root, comm, stream);
return this->GetNcclResult(broadcast_(sendbuff, recvbuff, count, datatype, root, comm, stream));
}
[[nodiscard]] ncclResult_t Allgather(const void* sendbuff, void* recvbuff, size_t sendcount,
[[nodiscard]] Result Allgather(const void* sendbuff, void* recvbuff, size_t sendcount,
ncclDataType_t datatype, ncclComm_t comm,
cudaStream_t stream) const {
CHECK(allgather_);
return this->allgather_(sendbuff, recvbuff, sendcount, datatype, comm, stream);
return this->GetNcclResult(allgather_(sendbuff, recvbuff, sendcount, datatype, comm, stream));
}
[[nodiscard]] ncclResult_t CommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId,
[[nodiscard]] Result CommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId,
int rank) const {
CHECK(comm_init_rank_);
return this->comm_init_rank_(comm, nranks, commId, rank);
return this->GetNcclResult(this->comm_init_rank_(comm, nranks, commId, rank));
}
[[nodiscard]] ncclResult_t CommDestroy(ncclComm_t comm) const {
CHECK(comm_destroy_);
return this->comm_destroy_(comm);
[[nodiscard]] Result CommDestroy(ncclComm_t comm) const {
return this->GetNcclResult(comm_destroy_(comm));
}
[[nodiscard]] ncclResult_t GetUniqueId(ncclUniqueId* uniqueId) const {
CHECK(get_uniqueid_);
return this->get_uniqueid_(uniqueId);
[[nodiscard]] Result GetUniqueId(ncclUniqueId* uniqueId) const {
return this->GetNcclResult(get_uniqueid_(uniqueId));
}
[[nodiscard]] ncclResult_t Send(const void* sendbuff, size_t count, ncclDataType_t datatype,
int peer, ncclComm_t comm, cudaStream_t stream) {
CHECK(send_);
return send_(sendbuff, count, datatype, peer, comm, stream);
[[nodiscard]] Result Send(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer,
ncclComm_t comm, cudaStream_t stream) {
return this->GetNcclResult(send_(sendbuff, count, datatype, peer, comm, stream));
}
[[nodiscard]] ncclResult_t Recv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer,
[[nodiscard]] Result Recv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer,
ncclComm_t comm, cudaStream_t stream) const {
CHECK(recv_);
return recv_(recvbuff, count, datatype, peer, comm, stream);
return this->GetNcclResult(recv_(recvbuff, count, datatype, peer, comm, stream));
}
[[nodiscard]] ncclResult_t GroupStart() const {
CHECK(group_start_);
return group_start_();
}
[[nodiscard]] ncclResult_t GroupEnd() const {
CHECK(group_end_);
return group_end_();
}
[[nodiscard]] Result GroupStart() const { return this->GetNcclResult(group_start_()); }
[[nodiscard]] Result GroupEnd() const { return this->GetNcclResult(group_end_()); }
[[nodiscard]] const char* GetErrorString(ncclResult_t result) const {
return get_error_string_(result);
}


@ -36,10 +36,6 @@
#include "xgboost/logging.h"
#include "xgboost/span.h"
#ifdef XGBOOST_USE_NCCL
#include "nccl.h"
#endif // XGBOOST_USE_NCCL
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
#include "rmm/mr/device/per_device_resource.hpp"
#include "rmm/mr/device/thrust_allocator_adaptor.hpp"


@ -25,15 +25,18 @@ TEST_F(CommTest, Channel) {
WorkerForTest worker{host, port, timeout, n_workers, i};
if (i % 2 == 0) {
auto p_chan = worker.Comm().Chan(i + 1);
p_chan->SendAll(
auto rc = Success() << [&] {
return p_chan->SendAll(
EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
auto rc = p_chan->Block();
} << [&] { return p_chan->Block(); };
ASSERT_TRUE(rc.OK()) << rc.Report();
} else {
auto p_chan = worker.Comm().Chan(i - 1);
std::int32_t r{-1};
p_chan->RecvAll(EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
auto rc = p_chan->Block();
auto rc = Success() << [&] {
return p_chan->RecvAll(
EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
} << [&] { return p_chan->Block(); };
ASSERT_TRUE(rc.OK()) << rc.Report();
ASSERT_EQ(r, i - 1);
}


@ -23,7 +23,7 @@ TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
TEST(NcclDeviceCommunicatorSimpleTest, SystemError) {
auto stub = std::make_shared<NcclStub>(DefaultNcclName());
auto rc = GetNCCLResult(stub, ncclSystemError);
auto rc = stub->GetNcclResult(ncclSystemError);
auto msg = rc.Report();
ASSERT_TRUE(msg.find("environment variables") != std::string::npos);
}