Cleanup code for distributed training. (#9805)
* Cleanup code for distributed training. - Merge `GetNcclResult` into nccl stub. - Split up utilities from the main dask module. - Let Channel return `Result` to accommodate nccl channel. - Remove old `use_label_encoder` parameter.
This commit is contained in:
parent
e9260de3f3
commit
8fe1a2213c
@ -94,6 +94,8 @@ from xgboost.sklearn import (
|
||||
from xgboost.tracker import RabitTracker, get_host_ip
|
||||
from xgboost.training import train as worker_train
|
||||
|
||||
from .utils import get_n_threads
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import dask
|
||||
import distributed
|
||||
@ -908,6 +910,34 @@ async def _check_workers_are_alive(
|
||||
raise RuntimeError(f"Missing required workers: {missing_workers}")
|
||||
|
||||
|
||||
def _get_dmatrices(
|
||||
train_ref: dict,
|
||||
train_id: int,
|
||||
*refs: dict,
|
||||
evals_id: Sequence[int],
|
||||
evals_name: Sequence[str],
|
||||
n_threads: int,
|
||||
) -> Tuple[DMatrix, List[Tuple[DMatrix, str]]]:
|
||||
Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
|
||||
evals: List[Tuple[DMatrix, str]] = []
|
||||
for i, ref in enumerate(refs):
|
||||
if evals_id[i] == train_id:
|
||||
evals.append((Xy, evals_name[i]))
|
||||
continue
|
||||
if ref.get("ref", None) is not None:
|
||||
if ref["ref"] != train_id:
|
||||
raise ValueError(
|
||||
"The training DMatrix should be used as a reference to evaluation"
|
||||
" `QuantileDMatrix`."
|
||||
)
|
||||
del ref["ref"]
|
||||
eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads, ref=Xy)
|
||||
else:
|
||||
eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads)
|
||||
evals.append((eval_Xy, evals_name[i]))
|
||||
return Xy, evals
|
||||
|
||||
|
||||
async def _train_async(
|
||||
client: "distributed.Client",
|
||||
global_config: Dict[str, Any],
|
||||
@ -940,41 +970,20 @@ async def _train_async(
|
||||
) -> Optional[TrainReturnT]:
|
||||
worker = distributed.get_worker()
|
||||
local_param = parameters.copy()
|
||||
n_threads = 0
|
||||
# dask worker nthreads, "state" is available in 2022.6.1
|
||||
dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
|
||||
for p in ["nthread", "n_jobs"]:
|
||||
if (
|
||||
local_param.get(p, None) is not None
|
||||
and local_param.get(p, dwnt) != dwnt
|
||||
):
|
||||
LOGGER.info("Overriding `nthreads` defined in dask worker.")
|
||||
n_threads = local_param[p]
|
||||
break
|
||||
if n_threads == 0 or n_threads is None:
|
||||
n_threads = dwnt
|
||||
n_threads = get_n_threads(local_param, worker)
|
||||
local_param.update({"nthread": n_threads, "n_jobs": n_threads})
|
||||
|
||||
local_history: TrainingCallback.EvalsLog = {}
|
||||
|
||||
with CommunicatorContext(**rabit_args), config.config_context(**global_config):
|
||||
Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
|
||||
evals: List[Tuple[DMatrix, str]] = []
|
||||
for i, ref in enumerate(refs):
|
||||
if evals_id[i] == train_id:
|
||||
evals.append((Xy, evals_name[i]))
|
||||
continue
|
||||
if ref.get("ref", None) is not None:
|
||||
if ref["ref"] != train_id:
|
||||
raise ValueError(
|
||||
"The training DMatrix should be used as a reference"
|
||||
" to evaluation `QuantileDMatrix`."
|
||||
Xy, evals = _get_dmatrices(
|
||||
train_ref,
|
||||
train_id,
|
||||
*refs,
|
||||
evals_id=evals_id,
|
||||
evals_name=evals_name,
|
||||
n_threads=n_threads,
|
||||
)
|
||||
del ref["ref"]
|
||||
eval_Xy = _dmatrix_from_list_of_parts(
|
||||
**ref, nthread=n_threads, ref=Xy
|
||||
)
|
||||
else:
|
||||
eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads)
|
||||
evals.append((eval_Xy, evals_name[i]))
|
||||
|
||||
booster = worker_train(
|
||||
params=local_param,
|
||||
|
||||
24
python-package/xgboost/dask/utils.py
Normal file
24
python-package/xgboost/dask/utils.py
Normal file
@ -0,0 +1,24 @@
|
||||
"""Utilities for the XGBoost Dask interface."""
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
LOGGER = logging.getLogger("[xgboost.dask]")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import distributed
|
||||
|
||||
|
||||
def get_n_threads(local_param: Dict[str, Any], worker: "distributed.Worker") -> int:
|
||||
"""Get the number of threads from a worker and the user-supplied parameters."""
|
||||
# dask worker nthreads, "state" is available in 2022.6.1
|
||||
dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
|
||||
n_threads = None
|
||||
for p in ["nthread", "n_jobs"]:
|
||||
if local_param.get(p, None) is not None and local_param.get(p, dwnt) != dwnt:
|
||||
LOGGER.info("Overriding `nthreads` defined in dask worker.")
|
||||
n_threads = local_param[p]
|
||||
break
|
||||
if n_threads == 0 or n_threads is None:
|
||||
n_threads = dwnt
|
||||
return n_threads
|
||||
@ -808,7 +808,6 @@ class XGBModel(XGBModelBase):
|
||||
"kwargs",
|
||||
"missing",
|
||||
"n_estimators",
|
||||
"use_label_encoder",
|
||||
"enable_categorical",
|
||||
"early_stopping_rounds",
|
||||
"callbacks",
|
||||
|
||||
@ -138,7 +138,6 @@ _inverse_pyspark_param_alias_map = {v: k for k, v in _pyspark_param_alias_map.it
|
||||
_unsupported_xgb_params = [
|
||||
"gpu_id", # we have "device" pyspark param instead.
|
||||
"enable_categorical", # Use feature_types param to specify categorical feature instead
|
||||
"use_label_encoder",
|
||||
"n_jobs", # Do not allow user to set it, will use `spark.task.cpus` value instead.
|
||||
"nthread", # Ditto
|
||||
]
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2019-2023 by XGBoost Contributors
|
||||
* Copyright 2019-2023, XGBoost Contributors
|
||||
*/
|
||||
#include <thrust/transform.h> // for transform
|
||||
|
||||
@ -15,6 +15,9 @@
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/learner.h"
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include <nccl.h>
|
||||
#endif
|
||||
|
||||
namespace xgboost {
|
||||
void XGBBuildInfoDevice(Json *p_info) {
|
||||
|
||||
@ -26,18 +26,19 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
}
|
||||
|
||||
for (std::int32_t r = 0; r < world; ++r) {
|
||||
auto rc = Success() << [&] {
|
||||
auto send_rank = (rank + world - r + worker_off) % world;
|
||||
auto send_off = send_rank * segment_size;
|
||||
send_off = std::min(send_off, data.size_bytes());
|
||||
auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off));
|
||||
next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
|
||||
|
||||
return next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
|
||||
} << [&] {
|
||||
auto recv_rank = (rank + world - r - 1 + worker_off) % world;
|
||||
auto recv_off = recv_rank * segment_size;
|
||||
recv_off = std::min(recv_off, data.size_bytes());
|
||||
auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off));
|
||||
prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
|
||||
auto rc = prev_ch->Block();
|
||||
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
|
||||
} << [&] { return prev_ch->Block(); };
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
@ -78,19 +79,19 @@ namespace detail {
|
||||
auto next_ch = comm.Chan(next);
|
||||
|
||||
for (std::int32_t r = 0; r < world; ++r) {
|
||||
auto rc = Success() << [&] {
|
||||
auto send_rank = (rank + world - r) % world;
|
||||
auto send_off = offset[send_rank];
|
||||
auto send_size = sizes[send_rank];
|
||||
auto send_seg = erased_result.subspan(send_off, send_size);
|
||||
next_ch->SendAll(send_seg);
|
||||
|
||||
return next_ch->SendAll(send_seg);
|
||||
} << [&] {
|
||||
auto recv_rank = (rank + world - r - 1) % world;
|
||||
auto recv_off = offset[recv_rank];
|
||||
auto recv_size = sizes[recv_rank];
|
||||
auto recv_seg = erased_result.subspan(recv_off, recv_size);
|
||||
prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
|
||||
|
||||
auto rc = prev_ch->Block();
|
||||
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
|
||||
} << [&] { return prev_ch->Block(); };
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
@ -37,7 +37,10 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg);
|
||||
auto send_seg = data.subspan(send_off, seg_nbytes);
|
||||
|
||||
next_ch->SendAll(send_seg);
|
||||
auto rc = next_ch->SendAll(send_seg);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
// receive from ring prev
|
||||
auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg;
|
||||
@ -47,8 +50,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
auto recv_seg = data.subspan(recv_off, seg_nbytes);
|
||||
auto seg = s_buf.subspan(0, recv_seg.size());
|
||||
|
||||
prev_ch->RecvAll(seg);
|
||||
auto rc = comm.Block();
|
||||
rc = std::move(rc) << [&] { return prev_ch->RecvAll(seg); } << [&] { return comm.Block(); };
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
@ -62,8 +62,8 @@ Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t
|
||||
|
||||
if (shifted_rank != 0) { // not root
|
||||
auto parent = ShiftRight(ShiftedParentRank(shifted_rank, depth), world, root);
|
||||
comm.Chan(parent)->RecvAll(data);
|
||||
auto rc = comm.Chan(parent)->Block();
|
||||
auto rc = Success() << [&] { return comm.Chan(parent)->RecvAll(data); }
|
||||
<< [&] { return comm.Chan(parent)->Block(); };
|
||||
if (!rc.OK()) {
|
||||
return Fail("broadcast failed.", std::move(rc));
|
||||
}
|
||||
@ -75,7 +75,10 @@ Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t
|
||||
auto sft_peer = shifted_rank + (1 << i);
|
||||
auto peer = ShiftRight(sft_peer, world, root);
|
||||
CHECK_NE(peer, root);
|
||||
comm.Chan(peer)->SendAll(data);
|
||||
auto rc = comm.Chan(peer)->SendAll(data);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -79,8 +79,8 @@ void RunBitwiseAllreduce(dh::CUDAStreamView stream, common::Span<std::int8_t> ou
|
||||
|
||||
// First gather data from all the workers.
|
||||
CHECK(handle);
|
||||
auto rc = GetNCCLResult(stub, stub->Allgather(data.data(), device_buffer, data.size(), ncclInt8,
|
||||
handle, pcomm->Stream()));
|
||||
auto rc =
|
||||
stub->Allgather(data.data(), device_buffer, data.size(), ncclInt8, handle, pcomm->Stream());
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
@ -140,9 +140,8 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
|
||||
return DispatchDType(type, [=](auto t) {
|
||||
using T = decltype(t);
|
||||
auto rdata = common::RestoreType<T>(data);
|
||||
auto rc = stub->Allreduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
|
||||
return stub->Allreduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
|
||||
GetNCCLRedOp(op), nccl->Handle(), nccl->Stream());
|
||||
return GetNCCLResult(stub, rc);
|
||||
});
|
||||
}
|
||||
} << [&] { return nccl->Block(); };
|
||||
@ -158,8 +157,8 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
|
||||
auto stub = nccl->Stub();
|
||||
|
||||
return Success() << [&] {
|
||||
return GetNCCLResult(stub, stub->Broadcast(data.data(), data.data(), data.size_bytes(),
|
||||
ncclInt8, root, nccl->Handle(), nccl->Stream()));
|
||||
return stub->Broadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root,
|
||||
nccl->Handle(), nccl->Stream());
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
|
||||
@ -174,8 +173,8 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
|
||||
|
||||
auto send = data.subspan(comm.Rank() * size, size);
|
||||
return Success() << [&] {
|
||||
return GetNCCLResult(stub, stub->Allgather(send.data(), data.data(), size, ncclInt8,
|
||||
nccl->Handle(), nccl->Stream()));
|
||||
return stub->Allgather(send.data(), data.data(), size, ncclInt8, nccl->Handle(),
|
||||
nccl->Stream());
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
|
||||
@ -188,19 +187,19 @@ namespace cuda_impl {
|
||||
Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes, common::Span<std::int8_t> recv) {
|
||||
auto stub = comm->Stub();
|
||||
return Success() << [&stub] { return GetNCCLResult(stub, stub->GroupStart()); } << [&] {
|
||||
return Success() << [&stub] { return stub->GroupStart(); } << [&] {
|
||||
std::size_t offset = 0;
|
||||
for (std::int32_t r = 0; r < comm->World(); ++r) {
|
||||
auto as_bytes = sizes[r];
|
||||
auto rc = stub->Broadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
|
||||
ncclInt8, r, comm->Handle(), dh::DefaultStream());
|
||||
if (rc != ncclSuccess) {
|
||||
return GetNCCLResult(stub, rc);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
offset += as_bytes;
|
||||
}
|
||||
return Success();
|
||||
} << [&] { return GetNCCLResult(stub, stub->GroupEnd()); };
|
||||
} << [&] { return stub->GroupEnd(); };
|
||||
}
|
||||
} // namespace cuda_impl
|
||||
|
||||
@ -217,7 +216,7 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const>
|
||||
|
||||
switch (algo) {
|
||||
case AllgatherVAlgo::kRing: {
|
||||
return Success() << [&] { return GetNCCLResult(stub, stub->GroupStart()); } << [&] {
|
||||
return Success() << [&] { return stub->GroupStart(); } << [&] {
|
||||
// get worker offset
|
||||
detail::AllgatherVOffset(sizes, recv_segments);
|
||||
// copy data
|
||||
@ -228,7 +227,7 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const>
|
||||
}
|
||||
return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
|
||||
} << [&] {
|
||||
return GetNCCLResult(stub, stub->GroupEnd());
|
||||
return stub->GroupEnd();
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
case AllgatherVAlgo::kBcast: {
|
||||
|
||||
@ -26,7 +26,7 @@ Result GetUniqueId(Comm const& comm, std::shared_ptr<NcclStub> stub, std::shared
|
||||
static const int kRootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (comm.Rank() == kRootRank) {
|
||||
auto rc = GetNCCLResult(stub, stub->GetUniqueId(&id));
|
||||
auto rc = stub->GetUniqueId(&id);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
auto rc = coll->Broadcast(
|
||||
@ -99,11 +99,9 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
|
||||
<< "Multiple processes within communication group running on same CUDA "
|
||||
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
||||
|
||||
rc = std::move(rc) << [&] {
|
||||
return GetUniqueId(root, this->stub_, pimpl, &nccl_unique_id_);
|
||||
} << [&] {
|
||||
return GetNCCLResult(this->stub_, this->stub_->CommInitRank(&nccl_comm_, root.World(),
|
||||
nccl_unique_id_, root.Rank()));
|
||||
rc = std::move(rc) << [&] { return GetUniqueId(root, this->stub_, pimpl, &nccl_unique_id_); } <<
|
||||
[&] {
|
||||
return this->stub_->CommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank());
|
||||
};
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
|
||||
@ -115,7 +113,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
|
||||
|
||||
NCCLComm::~NCCLComm() {
|
||||
if (nccl_comm_) {
|
||||
auto rc = GetNCCLResult(stub_, stub_->CommDestroy(nccl_comm_));
|
||||
auto rc = stub_->CommDestroy(nccl_comm_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
}
|
||||
|
||||
@ -52,25 +52,6 @@ class NCCLComm : public Comm {
|
||||
}
|
||||
};
|
||||
|
||||
inline Result GetNCCLResult(std::shared_ptr<NcclStub> stub, ncclResult_t code) {
|
||||
if (code == ncclSuccess) {
|
||||
return Success();
|
||||
}
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "NCCL failure: " << stub->GetErrorString(code) << ".";
|
||||
if (code == ncclUnhandledCudaError) {
|
||||
// nccl usually preserves the last error so we can get more details.
|
||||
auto err = cudaPeekAtLastError();
|
||||
ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n";
|
||||
} else if (code == ncclSystemError) {
|
||||
ss << " This might be caused by a network configuration issue. Please consider specifying "
|
||||
"the network interface for NCCL via environment variables listed in its reference: "
|
||||
"`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
|
||||
}
|
||||
return Fail(ss.str());
|
||||
}
|
||||
|
||||
class NCCLChannel : public Channel {
|
||||
std::int32_t rank_{-1};
|
||||
ncclComm_t nccl_comm_{};
|
||||
@ -86,13 +67,11 @@ class NCCLChannel : public Channel {
|
||||
Channel{comm, nullptr},
|
||||
stream_{stream} {}
|
||||
|
||||
void SendAll(std::int8_t const* ptr, std::size_t n) override {
|
||||
auto rc = GetNCCLResult(stub_, stub_->Send(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
[[nodiscard]] Result SendAll(std::int8_t const* ptr, std::size_t n) override {
|
||||
return stub_->Send(ptr, n, ncclInt8, rank_, nccl_comm_, stream_);
|
||||
}
|
||||
void RecvAll(std::int8_t* ptr, std::size_t n) override {
|
||||
auto rc = GetNCCLResult(stub_, stub_->Recv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
[[nodiscard]] Result RecvAll(std::int8_t* ptr, std::size_t n) override {
|
||||
return stub_->Recv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_);
|
||||
}
|
||||
[[nodiscard]] Result Block() override {
|
||||
auto rc = stream_.Sync(false);
|
||||
|
||||
@ -135,21 +135,25 @@ class Channel {
|
||||
explicit Channel(Comm const& comm, std::shared_ptr<TCPSocket> sock)
|
||||
: sock_{std::move(sock)}, comm_{comm} {}
|
||||
|
||||
virtual void SendAll(std::int8_t const* ptr, std::size_t n) {
|
||||
[[nodiscard]] virtual Result SendAll(std::int8_t const* ptr, std::size_t n) {
|
||||
Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast<std::int8_t*>(ptr), n, sock_.get(), 0};
|
||||
CHECK(sock_.get());
|
||||
comm_.Submit(std::move(op));
|
||||
return Success();
|
||||
}
|
||||
void SendAll(common::Span<std::int8_t const> data) {
|
||||
this->SendAll(data.data(), data.size_bytes());
|
||||
[[nodiscard]] Result SendAll(common::Span<std::int8_t const> data) {
|
||||
return this->SendAll(data.data(), data.size_bytes());
|
||||
}
|
||||
|
||||
virtual void RecvAll(std::int8_t* ptr, std::size_t n) {
|
||||
[[nodiscard]] virtual Result RecvAll(std::int8_t* ptr, std::size_t n) {
|
||||
Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0};
|
||||
CHECK(sock_.get());
|
||||
comm_.Submit(std::move(op));
|
||||
return Success();
|
||||
}
|
||||
[[nodiscard]] Result RecvAll(common::Span<std::int8_t> data) {
|
||||
return this->RecvAll(data.data(), data.size_bytes());
|
||||
}
|
||||
void RecvAll(common::Span<std::int8_t> data) { this->RecvAll(data.data(), data.size_bytes()); }
|
||||
|
||||
[[nodiscard]] auto Socket() const { return sock_; }
|
||||
[[nodiscard]] virtual Result Block() { return comm_.Block(); }
|
||||
|
||||
@ -46,8 +46,7 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sy
|
||||
|
||||
nccl_unique_id_ = GetUniqueId();
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto rc =
|
||||
GetNCCLResult(stub_, stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
|
||||
auto rc = stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
|
||||
@ -56,7 +55,7 @@ NcclDeviceCommunicator::~NcclDeviceCommunicator() {
|
||||
return;
|
||||
}
|
||||
if (nccl_comm_) {
|
||||
auto rc = GetNCCLResult(stub_, stub_->CommDestroy(nccl_comm_));
|
||||
auto rc = stub_->CommDestroy(nccl_comm_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
|
||||
@ -143,9 +142,8 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
|
||||
auto *device_buffer = buffer.data().get();
|
||||
|
||||
// First gather data from all the workers.
|
||||
auto rc = GetNCCLResult(
|
||||
stub_, stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
|
||||
nccl_comm_, dh::DefaultStream()));
|
||||
auto rc = stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
|
||||
nccl_comm_, dh::DefaultStream());
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
if (needs_sync_) {
|
||||
dh::DefaultStream().Sync();
|
||||
@ -178,9 +176,9 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
|
||||
if (IsBitwiseOp(op)) {
|
||||
BitwiseAllReduce(send_receive_buffer, count, data_type, op);
|
||||
} else {
|
||||
auto rc = GetNCCLResult(stub_, stub_->Allreduce(send_receive_buffer, send_receive_buffer, count,
|
||||
GetNcclDataType(data_type), GetNcclRedOp(op),
|
||||
nccl_comm_, dh::DefaultStream()));
|
||||
auto rc = stub_->Allreduce(send_receive_buffer, send_receive_buffer, count,
|
||||
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
|
||||
dh::DefaultStream());
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
allreduce_bytes_ += count * GetTypeSize(data_type);
|
||||
@ -194,8 +192,8 @@ void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_bu
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto rc = GetNCCLResult(stub_, stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8,
|
||||
nccl_comm_, dh::DefaultStream()));
|
||||
auto rc = stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
|
||||
dh::DefaultStream());
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
|
||||
@ -216,19 +214,18 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
|
||||
receive_buffer->resize(total_bytes);
|
||||
|
||||
size_t offset = 0;
|
||||
auto rc = Success() << [&] { return GetNCCLResult(stub_, stub_->GroupStart()); } << [&] {
|
||||
auto rc = Success() << [&] { return stub_->GroupStart(); } << [&] {
|
||||
for (int32_t i = 0; i < world_size_; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
auto rc = GetNCCLResult(
|
||||
stub_, stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
|
||||
ncclChar, i, nccl_comm_, dh::DefaultStream()));
|
||||
auto rc = stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
|
||||
ncclChar, i, nccl_comm_, dh::DefaultStream());
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
offset += as_bytes;
|
||||
}
|
||||
return Success();
|
||||
} << [&] { return GetNCCLResult(stub_, stub_->GroupEnd()); };
|
||||
} << [&] { return stub_->GroupEnd(); };
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::Synchronize() {
|
||||
|
||||
@ -66,7 +66,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
static const int kRootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (rank_ == kRootRank) {
|
||||
auto rc = GetNCCLResult(stub_, stub_->GetUniqueId(&id));
|
||||
auto rc = stub_->GetUniqueId(&id);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
|
||||
|
||||
@ -5,8 +5,11 @@
|
||||
#include "nccl_stub.h"
|
||||
|
||||
#include <cuda.h> // for CUDA_VERSION
|
||||
#include <cuda_runtime_api.h> // for cudaPeekAtLastError
|
||||
#include <dlfcn.h> // for dlclose, dlsym, dlopen
|
||||
#include <nccl.h>
|
||||
#include <thrust/system/cuda/error.h> // for cuda_category
|
||||
#include <thrust/system_error.h> // for system_error
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <sstream> // for stringstream
|
||||
@ -16,6 +19,25 @@
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
Result NcclStub::GetNcclResult(ncclResult_t code) const {
|
||||
if (code == ncclSuccess) {
|
||||
return Success();
|
||||
}
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "NCCL failure: " << this->GetErrorString(code) << ".";
|
||||
if (code == ncclUnhandledCudaError) {
|
||||
// nccl usually preserves the last error so we can get more details.
|
||||
auto err = cudaPeekAtLastError();
|
||||
ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n";
|
||||
} else if (code == ncclSystemError) {
|
||||
ss << " This might be caused by a network configuration issue. Please consider specifying "
|
||||
"the network interface for NCCL via environment variables listed in its reference: "
|
||||
"`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
|
||||
}
|
||||
return Fail(ss.str());
|
||||
}
|
||||
|
||||
NcclStub::NcclStub(StringView path) : path_{std::move(path)} {
|
||||
#if defined(XGBOOST_USE_DLOPEN_NCCL)
|
||||
CHECK(!path_.empty()) << "Empty path for NCCL.";
|
||||
|
||||
@ -8,9 +8,13 @@
|
||||
|
||||
#include <string> // for string
|
||||
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/string_view.h" // for StringView
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
* @brief A stub for NCCL to facilitate dynamic loading.
|
||||
*/
|
||||
class NcclStub {
|
||||
#if defined(XGBOOST_USE_DLOPEN_NCCL)
|
||||
void* handle_{nullptr};
|
||||
@ -30,61 +34,48 @@ class NcclStub {
|
||||
decltype(ncclGetErrorString)* get_error_string_{nullptr};
|
||||
decltype(ncclGetVersion)* get_version_{nullptr};
|
||||
|
||||
public:
|
||||
Result GetNcclResult(ncclResult_t code) const;
|
||||
|
||||
public:
|
||||
explicit NcclStub(StringView path);
|
||||
~NcclStub();
|
||||
|
||||
[[nodiscard]] ncclResult_t Allreduce(const void* sendbuff, void* recvbuff, size_t count,
|
||||
[[nodiscard]] Result Allreduce(const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
cudaStream_t stream) const {
|
||||
CHECK(allreduce_);
|
||||
return this->allreduce_(sendbuff, recvbuff, count, datatype, op, comm, stream);
|
||||
return this->GetNcclResult(allreduce_(sendbuff, recvbuff, count, datatype, op, comm, stream));
|
||||
}
|
||||
[[nodiscard]] ncclResult_t Broadcast(const void* sendbuff, void* recvbuff, size_t count,
|
||||
[[nodiscard]] Result Broadcast(const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, int root, ncclComm_t comm,
|
||||
cudaStream_t stream) const {
|
||||
CHECK(broadcast_);
|
||||
return this->broadcast_(sendbuff, recvbuff, count, datatype, root, comm, stream);
|
||||
return this->GetNcclResult(broadcast_(sendbuff, recvbuff, count, datatype, root, comm, stream));
|
||||
}
|
||||
[[nodiscard]] ncclResult_t Allgather(const void* sendbuff, void* recvbuff, size_t sendcount,
|
||||
[[nodiscard]] Result Allgather(const void* sendbuff, void* recvbuff, size_t sendcount,
|
||||
ncclDataType_t datatype, ncclComm_t comm,
|
||||
cudaStream_t stream) const {
|
||||
CHECK(allgather_);
|
||||
return this->allgather_(sendbuff, recvbuff, sendcount, datatype, comm, stream);
|
||||
return this->GetNcclResult(allgather_(sendbuff, recvbuff, sendcount, datatype, comm, stream));
|
||||
}
|
||||
[[nodiscard]] ncclResult_t CommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId,
|
||||
[[nodiscard]] Result CommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId,
|
||||
int rank) const {
|
||||
CHECK(comm_init_rank_);
|
||||
return this->comm_init_rank_(comm, nranks, commId, rank);
|
||||
return this->GetNcclResult(this->comm_init_rank_(comm, nranks, commId, rank));
|
||||
}
|
||||
[[nodiscard]] ncclResult_t CommDestroy(ncclComm_t comm) const {
|
||||
CHECK(comm_destroy_);
|
||||
return this->comm_destroy_(comm);
|
||||
[[nodiscard]] Result CommDestroy(ncclComm_t comm) const {
|
||||
return this->GetNcclResult(comm_destroy_(comm));
|
||||
}
|
||||
|
||||
[[nodiscard]] ncclResult_t GetUniqueId(ncclUniqueId* uniqueId) const {
|
||||
CHECK(get_uniqueid_);
|
||||
return this->get_uniqueid_(uniqueId);
|
||||
[[nodiscard]] Result GetUniqueId(ncclUniqueId* uniqueId) const {
|
||||
return this->GetNcclResult(get_uniqueid_(uniqueId));
|
||||
}
|
||||
[[nodiscard]] ncclResult_t Send(const void* sendbuff, size_t count, ncclDataType_t datatype,
|
||||
int peer, ncclComm_t comm, cudaStream_t stream) {
|
||||
CHECK(send_);
|
||||
return send_(sendbuff, count, datatype, peer, comm, stream);
|
||||
[[nodiscard]] Result Send(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer,
|
||||
ncclComm_t comm, cudaStream_t stream) {
|
||||
return this->GetNcclResult(send_(sendbuff, count, datatype, peer, comm, stream));
|
||||
}
|
||||
[[nodiscard]] ncclResult_t Recv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer,
|
||||
[[nodiscard]] Result Recv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer,
|
||||
ncclComm_t comm, cudaStream_t stream) const {
|
||||
CHECK(recv_);
|
||||
return recv_(recvbuff, count, datatype, peer, comm, stream);
|
||||
return this->GetNcclResult(recv_(recvbuff, count, datatype, peer, comm, stream));
|
||||
}
|
||||
[[nodiscard]] ncclResult_t GroupStart() const {
|
||||
CHECK(group_start_);
|
||||
return group_start_();
|
||||
}
|
||||
[[nodiscard]] ncclResult_t GroupEnd() const {
|
||||
CHECK(group_end_);
|
||||
return group_end_();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result GroupStart() const { return this->GetNcclResult(group_start_()); }
|
||||
[[nodiscard]] Result GroupEnd() const { return this->GetNcclResult(group_end_()); }
|
||||
[[nodiscard]] const char* GetErrorString(ncclResult_t result) const {
|
||||
return get_error_string_(result);
|
||||
}
|
||||
|
||||
@ -36,10 +36,6 @@
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/span.h"
|
||||
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include "nccl.h"
|
||||
#endif // XGBOOST_USE_NCCL
|
||||
|
||||
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
#include "rmm/mr/device/per_device_resource.hpp"
|
||||
#include "rmm/mr/device/thrust_allocator_adaptor.hpp"
|
||||
|
||||
@ -25,15 +25,18 @@ TEST_F(CommTest, Channel) {
|
||||
WorkerForTest worker{host, port, timeout, n_workers, i};
|
||||
if (i % 2 == 0) {
|
||||
auto p_chan = worker.Comm().Chan(i + 1);
|
||||
p_chan->SendAll(
|
||||
auto rc = Success() << [&] {
|
||||
return p_chan->SendAll(
|
||||
EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
|
||||
auto rc = p_chan->Block();
|
||||
} << [&] { return p_chan->Block(); };
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
} else {
|
||||
auto p_chan = worker.Comm().Chan(i - 1);
|
||||
std::int32_t r{-1};
|
||||
p_chan->RecvAll(EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
|
||||
auto rc = p_chan->Block();
|
||||
auto rc = Success() << [&] {
|
||||
return p_chan->RecvAll(
|
||||
EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
|
||||
} << [&] { return p_chan->Block(); };
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_EQ(r, i - 1);
|
||||
}
|
||||
|
||||
@ -23,7 +23,7 @@ TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
|
||||
|
||||
TEST(NcclDeviceCommunicatorSimpleTest, SystemError) {
|
||||
auto stub = std::make_shared<NcclStub>(DefaultNcclName());
|
||||
auto rc = GetNCCLResult(stub, ncclSystemError);
|
||||
auto rc = stub->GetNcclResult(ncclSystemError);
|
||||
auto msg = rc.Report();
|
||||
ASSERT_TRUE(msg.find("environment variables") != std::string::npos);
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user