diff --git a/python-package/xgboost/dask/__init__.py b/python-package/xgboost/dask/__init__.py
index a58c0f225..068b1e6ea 100644
--- a/python-package/xgboost/dask/__init__.py
+++ b/python-package/xgboost/dask/__init__.py
@@ -94,6 +94,8 @@ from xgboost.sklearn import (
 from xgboost.tracker import RabitTracker, get_host_ip
 from xgboost.training import train as worker_train
 
+from .utils import get_n_threads
+
 if TYPE_CHECKING:
     import dask
     import distributed
@@ -908,6 +910,34 @@ async def _check_workers_are_alive(
         raise RuntimeError(f"Missing required workers: {missing_workers}")
 
 
+def _get_dmatrices(
+    train_ref: dict,
+    train_id: int,
+    *refs: dict,
+    evals_id: Sequence[int],
+    evals_name: Sequence[str],
+    n_threads: int,
+) -> Tuple[DMatrix, List[Tuple[DMatrix, str]]]:
+    Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
+    evals: List[Tuple[DMatrix, str]] = []
+    for i, ref in enumerate(refs):
+        if evals_id[i] == train_id:
+            evals.append((Xy, evals_name[i]))
+            continue
+        if ref.get("ref", None) is not None:
+            if ref["ref"] != train_id:
+                raise ValueError(
+                    "The training DMatrix should be used as a reference to evaluation"
+                    " `QuantileDMatrix`."
+                )
+            del ref["ref"]
+            eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads, ref=Xy)
+        else:
+            eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads)
+        evals.append((eval_Xy, evals_name[i]))
+    return Xy, evals
+
+
 async def _train_async(
     client: "distributed.Client",
     global_config: Dict[str, Any],
@@ -940,41 +970,20 @@ async def _train_async(
     ) -> Optional[TrainReturnT]:
         worker = distributed.get_worker()
         local_param = parameters.copy()
-        n_threads = 0
-        # dask worker nthreads, "state" is available in 2022.6.1
-        dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
-        for p in ["nthread", "n_jobs"]:
-            if (
-                local_param.get(p, None) is not None
-                and local_param.get(p, dwnt) != dwnt
-            ):
-                LOGGER.info("Overriding `nthreads` defined in dask worker.")
-                n_threads = local_param[p]
-                break
-        if n_threads == 0 or n_threads is None:
-            n_threads = dwnt
+        n_threads = get_n_threads(local_param, worker)
         local_param.update({"nthread": n_threads, "n_jobs": n_threads})
+        local_history: TrainingCallback.EvalsLog = {}
+
         with CommunicatorContext(**rabit_args), config.config_context(**global_config):
-            Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
-            evals: List[Tuple[DMatrix, str]] = []
-            for i, ref in enumerate(refs):
-                if evals_id[i] == train_id:
-                    evals.append((Xy, evals_name[i]))
-                    continue
-                if ref.get("ref", None) is not None:
-                    if ref["ref"] != train_id:
-                        raise ValueError(
-                            "The training DMatrix should be used as a reference"
-                            " to evaluation `QuantileDMatrix`."
-                        )
-                    del ref["ref"]
-                    eval_Xy = _dmatrix_from_list_of_parts(
-                        **ref, nthread=n_threads, ref=Xy
-                    )
-                else:
-                    eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads)
-                evals.append((eval_Xy, evals_name[i]))
+            Xy, evals = _get_dmatrices(
+                train_ref,
+                train_id,
+                *refs,
+                evals_id=evals_id,
+                evals_name=evals_name,
+                n_threads=n_threads,
+            )
 
             booster = worker_train(
                 params=local_param,
diff --git a/python-package/xgboost/dask/utils.py b/python-package/xgboost/dask/utils.py
new file mode 100644
index 000000000..98e6029b5
--- /dev/null
+++ b/python-package/xgboost/dask/utils.py
@@ -0,0 +1,24 @@
+"""Utilities for the XGBoost Dask interface."""
+import logging
+from typing import TYPE_CHECKING, Any, Dict
+
+LOGGER = logging.getLogger("[xgboost.dask]")
+
+
+if TYPE_CHECKING:
+    import distributed
+
+
+def get_n_threads(local_param: Dict[str, Any], worker: "distributed.Worker") -> int:
+    """Get the number of threads from a worker and the user-supplied parameters."""
+    # dask worker nthreads, "state" is available in 2022.6.1
+    dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
+    n_threads = None
+    for p in ["nthread", "n_jobs"]:
+        if local_param.get(p, None) is not None and local_param.get(p, dwnt) != dwnt:
+            LOGGER.info("Overriding `nthreads` defined in dask worker.")
+            n_threads = local_param[p]
+            break
+    if n_threads == 0 or n_threads is None:
+        n_threads = dwnt
+    return n_threads
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 3906973a8..ea309bd94 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -808,7 +808,6 @@ class XGBModel(XGBModelBase):
             "kwargs",
             "missing",
             "n_estimators",
-            "use_label_encoder",
             "enable_categorical",
             "early_stopping_rounds",
             "callbacks",
diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index aa8c5b998..7ac01ff07 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -138,7 +138,6 @@ _inverse_pyspark_param_alias_map = {v: k for k, v in _pyspark_param_alias_map.it
 _unsupported_xgb_params = [
     "gpu_id",  # we have "device" pyspark param instead.
     "enable_categorical",  # Use feature_types param to specify categorical feature instead
-    "use_label_encoder",
     "n_jobs",  # Do not allow user to set it, will use `spark.task.cpus` value instead.
"nthread", # Ditto ] diff --git a/src/c_api/c_api.cu b/src/c_api/c_api.cu index 4ace8b7cc..47868f466 100644 --- a/src/c_api/c_api.cu +++ b/src/c_api/c_api.cu @@ -1,5 +1,5 @@ /** - * Copyright 2019-2023 by XGBoost Contributors + * Copyright 2019-2023, XGBoost Contributors */ #include // for transform @@ -15,6 +15,9 @@ #include "xgboost/data.h" #include "xgboost/json.h" #include "xgboost/learner.h" +#if defined(XGBOOST_USE_NCCL) +#include +#endif namespace xgboost { void XGBBuildInfoDevice(Json *p_info) { diff --git a/src/collective/allgather.cc b/src/collective/allgather.cc index fa369a9da..148cb6cd2 100644 --- a/src/collective/allgather.cc +++ b/src/collective/allgather.cc @@ -26,18 +26,19 @@ Result RingAllgather(Comm const& comm, common::Span data, std::size } for (std::int32_t r = 0; r < world; ++r) { - auto send_rank = (rank + world - r + worker_off) % world; - auto send_off = send_rank * segment_size; - send_off = std::min(send_off, data.size_bytes()); - auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off)); - next_ch->SendAll(send_seg.data(), send_seg.size_bytes()); - - auto recv_rank = (rank + world - r - 1 + worker_off) % world; - auto recv_off = recv_rank * segment_size; - recv_off = std::min(recv_off, data.size_bytes()); - auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off)); - prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes()); - auto rc = prev_ch->Block(); + auto rc = Success() << [&] { + auto send_rank = (rank + world - r + worker_off) % world; + auto send_off = send_rank * segment_size; + send_off = std::min(send_off, data.size_bytes()); + auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off)); + return next_ch->SendAll(send_seg.data(), send_seg.size_bytes()); + } << [&] { + auto recv_rank = (rank + world - r - 1 + worker_off) % world; + auto recv_off = recv_rank * segment_size; + recv_off = std::min(recv_off, data.size_bytes()); + auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off)); + return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes()); + } << [&] { return prev_ch->Block(); }; if (!rc.OK()) { return rc; } @@ -78,19 +79,19 @@ namespace detail { auto next_ch = comm.Chan(next); for (std::int32_t r = 0; r < world; ++r) { - auto send_rank = (rank + world - r) % world; - auto send_off = offset[send_rank]; - auto send_size = sizes[send_rank]; - auto send_seg = erased_result.subspan(send_off, send_size); - next_ch->SendAll(send_seg); - - auto recv_rank = (rank + world - r - 1) % world; - auto recv_off = offset[recv_rank]; - auto recv_size = sizes[recv_rank]; - auto recv_seg = erased_result.subspan(recv_off, recv_size); - prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes()); - - auto rc = prev_ch->Block(); + auto rc = Success() << [&] { + auto send_rank = (rank + world - r) % world; + auto send_off = offset[send_rank]; + auto send_size = sizes[send_rank]; + auto send_seg = erased_result.subspan(send_off, send_size); + return next_ch->SendAll(send_seg); + } << [&] { + auto recv_rank = (rank + world - r - 1) % world; + auto recv_off = offset[recv_rank]; + auto recv_size = sizes[recv_rank]; + auto recv_seg = erased_result.subspan(recv_off, recv_size); + return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes()); + } << [&] { return prev_ch->Block(); }; if (!rc.OK()) { return rc; } diff --git a/src/collective/allreduce.cc b/src/collective/allreduce.cc index f95a9a9f1..93b76355f 100644 --- 
+++ b/src/collective/allreduce.cc
@@ -37,7 +37,10 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span data,
     auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg);
     auto send_seg = data.subspan(send_off, seg_nbytes);
-    next_ch->SendAll(send_seg);
+    auto rc = next_ch->SendAll(send_seg);
+    if (!rc.OK()) {
+      return rc;
+    }
 
     // receive from ring prev
     auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg;
@@ -47,8 +50,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span data,
     auto recv_seg = data.subspan(recv_off, seg_nbytes);
     auto seg = s_buf.subspan(0, recv_seg.size());
-    prev_ch->RecvAll(seg);
-    auto rc = comm.Block();
+    rc = std::move(rc) << [&] { return prev_ch->RecvAll(seg); } << [&] { return comm.Block(); };
     if (!rc.OK()) {
       return rc;
     }
diff --git a/src/collective/broadcast.cc b/src/collective/broadcast.cc
index 660bb9130..e1ef60f86 100644
--- a/src/collective/broadcast.cc
+++ b/src/collective/broadcast.cc
@@ -62,8 +62,8 @@ Result Broadcast(Comm const& comm, common::Span data, std::int32_t
   if (shifted_rank != 0) {  // not root
     auto parent = ShiftRight(ShiftedParentRank(shifted_rank, depth), world, root);
-    comm.Chan(parent)->RecvAll(data);
-    auto rc = comm.Chan(parent)->Block();
+    auto rc = Success() << [&] { return comm.Chan(parent)->RecvAll(data); }
+                        << [&] { return comm.Chan(parent)->Block(); };
     if (!rc.OK()) {
       return Fail("broadcast failed.", std::move(rc));
     }
@@ -75,7 +75,10 @@ Result Broadcast(Comm const& comm, common::Span data, std::int32_t
       auto sft_peer = shifted_rank + (1 << i);
       auto peer = ShiftRight(sft_peer, world, root);
       CHECK_NE(peer, root);
-      comm.Chan(peer)->SendAll(data);
+      auto rc = comm.Chan(peer)->SendAll(data);
+      if (!rc.OK()) {
+        return rc;
+      }
     }
   }
diff --git a/src/collective/coll.cu b/src/collective/coll.cu
index 60072b6a5..d1b66a8ce 100644
--- a/src/collective/coll.cu
+++ b/src/collective/coll.cu
@@ -79,8 +79,8 @@ void RunBitwiseAllreduce(dh::CUDAStreamView stream, common::Span ou
   // First gather data from all the workers.
   CHECK(handle);
-  auto rc = GetNCCLResult(stub, stub->Allgather(data.data(), device_buffer, data.size(), ncclInt8,
-                                                handle, pcomm->Stream()));
+  auto rc =
+      stub->Allgather(data.data(), device_buffer, data.size(), ncclInt8, handle, pcomm->Stream());
   if (!rc.OK()) {
     return rc;
   }
@@ -140,9 +140,8 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
     return DispatchDType(type, [=](auto t) {
       using T = decltype(t);
       auto rdata = common::RestoreType(data);
-      auto rc = stub->Allreduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
-                                GetNCCLRedOp(op), nccl->Handle(), nccl->Stream());
-      return GetNCCLResult(stub, rc);
+      return stub->Allreduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
+                             GetNCCLRedOp(op), nccl->Handle(), nccl->Stream());
     });
   }
  } << [&] { return nccl->Block(); };
@@ -158,8 +157,8 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
   auto stub = nccl->Stub();
 
   return Success() << [&] {
-    return GetNCCLResult(stub, stub->Broadcast(data.data(), data.data(), data.size_bytes(),
-                                               ncclInt8, root, nccl->Handle(), nccl->Stream()));
+    return stub->Broadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root,
+                           nccl->Handle(), nccl->Stream());
   } << [&] { return nccl->Block(); };
 }
@@ -174,8 +173,8 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
   auto send = data.subspan(comm.Rank() * size, size);
 
   return Success() << [&] {
-    return GetNCCLResult(stub, stub->Allgather(send.data(), data.data(), size, ncclInt8,
-                                               nccl->Handle(), nccl->Stream()));
+    return stub->Allgather(send.data(), data.data(), size, ncclInt8, nccl->Handle(),
+                           nccl->Stream());
   } << [&] { return nccl->Block(); };
 }
@@ -188,19 +187,19 @@ namespace cuda_impl {
 Result BroadcastAllgatherV(NCCLComm const* comm, common::Span data,
                            common::Span sizes, common::Span recv) {
   auto stub = comm->Stub();
-  return Success() << [&stub] { return GetNCCLResult(stub, stub->GroupStart()); } << [&] {
+  return Success() << [&stub] { return stub->GroupStart(); } << [&] {
     std::size_t offset = 0;
     for (std::int32_t r = 0; r < comm->World(); ++r) {
       auto as_bytes = sizes[r];
       auto rc = stub->Broadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
                                 ncclInt8, r, comm->Handle(), dh::DefaultStream());
-      if (rc != ncclSuccess) {
-        return GetNCCLResult(stub, rc);
+      if (!rc.OK()) {
+        return rc;
       }
       offset += as_bytes;
     }
     return Success();
-  } << [&] { return GetNCCLResult(stub, stub->GroupEnd()); };
+  } << [&] { return stub->GroupEnd(); };
 }
 }  // namespace cuda_impl
@@ -217,7 +216,7 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span
 
   switch (algo) {
     case AllgatherVAlgo::kRing: {
-      return Success() << [&] { return GetNCCLResult(stub, stub->GroupStart()); } << [&] {
+      return Success() << [&] { return stub->GroupStart(); } << [&] {
         // get worker offset
         detail::AllgatherVOffset(sizes, recv_segments);
         // copy data
@@ -228,7 +227,7 @@
         }
         return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
       } << [&] {
-        return GetNCCLResult(stub, stub->GroupEnd());
+        return stub->GroupEnd();
       } << [&] { return nccl->Block(); };
     }
     case AllgatherVAlgo::kBcast: {
diff --git a/src/collective/comm.cu b/src/collective/comm.cu
index cc67def0a..56681253c 100644
--- a/src/collective/comm.cu
+++ b/src/collective/comm.cu
@@ -26,7 +26,7 @@ Result GetUniqueId(Comm const& comm, std::shared_ptr stub, std::shared
   static const int kRootRank = 0;
   ncclUniqueId id;
   if (comm.Rank() == kRootRank) {
-    auto rc = GetNCCLResult(stub, stub->GetUniqueId(&id));
+    auto rc = stub->GetUniqueId(&id);
     CHECK(rc.OK()) << rc.Report();
   }
   auto rc = coll->Broadcast(
@@ -99,12 +99,10 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p
       << "Multiple processes within communication group running on same CUDA "
       << "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
 
-  rc = std::move(rc) << [&] {
-    return GetUniqueId(root, this->stub_, pimpl, &nccl_unique_id_);
-  } << [&] {
-    return GetNCCLResult(this->stub_, this->stub_->CommInitRank(&nccl_comm_, root.World(),
-                                                                nccl_unique_id_, root.Rank()));
-  };
+  rc = std::move(rc) << [&] { return GetUniqueId(root, this->stub_, pimpl, &nccl_unique_id_); } <<
+      [&] {
+        return this->stub_->CommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank());
+      };
   CHECK(rc.OK()) << rc.Report();
 
   for (std::int32_t r = 0; r < root.World(); ++r) {
@@ -115,7 +113,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p
 
 NCCLComm::~NCCLComm() {
   if (nccl_comm_) {
-    auto rc = GetNCCLResult(stub_, stub_->CommDestroy(nccl_comm_));
+    auto rc = stub_->CommDestroy(nccl_comm_);
     CHECK(rc.OK()) << rc.Report();
   }
 }
diff --git a/src/collective/comm.cuh b/src/collective/comm.cuh
index ef537b5a9..a818d95f8 100644
--- a/src/collective/comm.cuh
+++ b/src/collective/comm.cuh
@@ -52,25 +52,6 @@ class NCCLComm : public Comm {
   }
 };
 
-inline Result GetNCCLResult(std::shared_ptr stub, ncclResult_t code) {
-  if (code == ncclSuccess) {
-    return Success();
-  }
-
-  std::stringstream ss;
-  ss << "NCCL failure: " << stub->GetErrorString(code) << ".";
-  if (code == ncclUnhandledCudaError) {
-    // nccl usually preserves the last error so we can get more details.
-    auto err = cudaPeekAtLastError();
-    ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n";
-  } else if (code == ncclSystemError) {
-    ss << " This might be caused by a network configuration issue. Please consider specifying "
Please consider specifying " - "the network interface for NCCL via environment variables listed in its reference: " - "`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n"; - } - return Fail(ss.str()); -} - class NCCLChannel : public Channel { std::int32_t rank_{-1}; ncclComm_t nccl_comm_{}; @@ -86,13 +67,11 @@ class NCCLChannel : public Channel { Channel{comm, nullptr}, stream_{stream} {} - void SendAll(std::int8_t const* ptr, std::size_t n) override { - auto rc = GetNCCLResult(stub_, stub_->Send(ptr, n, ncclInt8, rank_, nccl_comm_, stream_)); - CHECK(rc.OK()) << rc.Report(); + [[nodiscard]] Result SendAll(std::int8_t const* ptr, std::size_t n) override { + return stub_->Send(ptr, n, ncclInt8, rank_, nccl_comm_, stream_); } - void RecvAll(std::int8_t* ptr, std::size_t n) override { - auto rc = GetNCCLResult(stub_, stub_->Recv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_)); - CHECK(rc.OK()) << rc.Report(); + [[nodiscard]] Result RecvAll(std::int8_t* ptr, std::size_t n) override { + return stub_->Recv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_); } [[nodiscard]] Result Block() override { auto rc = stream_.Sync(false); diff --git a/src/collective/comm.h b/src/collective/comm.h index b2f519e3d..82aa2c45e 100644 --- a/src/collective/comm.h +++ b/src/collective/comm.h @@ -135,21 +135,25 @@ class Channel { explicit Channel(Comm const& comm, std::shared_ptr sock) : sock_{std::move(sock)}, comm_{comm} {} - virtual void SendAll(std::int8_t const* ptr, std::size_t n) { + [[nodiscard]] virtual Result SendAll(std::int8_t const* ptr, std::size_t n) { Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast(ptr), n, sock_.get(), 0}; CHECK(sock_.get()); comm_.Submit(std::move(op)); + return Success(); } - void SendAll(common::Span data) { - this->SendAll(data.data(), data.size_bytes()); + [[nodiscard]] Result SendAll(common::Span data) { + return this->SendAll(data.data(), data.size_bytes()); } - virtual void RecvAll(std::int8_t* ptr, std::size_t n) { + [[nodiscard]] virtual Result RecvAll(std::int8_t* ptr, std::size_t n) { Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0}; CHECK(sock_.get()); comm_.Submit(std::move(op)); + return Success(); + } + [[nodiscard]] Result RecvAll(common::Span data) { + return this->RecvAll(data.data(), data.size_bytes()); } - void RecvAll(common::Span data) { this->RecvAll(data.data(), data.size_bytes()); } [[nodiscard]] auto Socket() const { return sock_; } [[nodiscard]] virtual Result Block() { return comm_.Block(); } diff --git a/src/collective/nccl_device_communicator.cu b/src/collective/nccl_device_communicator.cu index 25b198bde..31c2d394d 100644 --- a/src/collective/nccl_device_communicator.cu +++ b/src/collective/nccl_device_communicator.cu @@ -46,8 +46,7 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sy nccl_unique_id_ = GetUniqueId(); dh::safe_cuda(cudaSetDevice(device_ordinal_)); - auto rc = - GetNCCLResult(stub_, stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_)); + auto rc = stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_); CHECK(rc.OK()) << rc.Report(); } @@ -56,7 +55,7 @@ NcclDeviceCommunicator::~NcclDeviceCommunicator() { return; } if (nccl_comm_) { - auto rc = GetNCCLResult(stub_, stub_->CommDestroy(nccl_comm_)); + auto rc = stub_->CommDestroy(nccl_comm_); CHECK(rc.OK()) << rc.Report(); } if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) { @@ -143,9 +142,8 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void 
   auto *device_buffer = buffer.data().get();
 
   // First gather data from all the workers.
-  auto rc = GetNCCLResult(
-      stub_, stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
-                              nccl_comm_, dh::DefaultStream()));
+  auto rc = stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
+                             nccl_comm_, dh::DefaultStream());
   CHECK(rc.OK()) << rc.Report();
   if (needs_sync_) {
     dh::DefaultStream().Sync();
@@ -178,9 +176,9 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
   if (IsBitwiseOp(op)) {
     BitwiseAllReduce(send_receive_buffer, count, data_type, op);
   } else {
-    auto rc = GetNCCLResult(stub_, stub_->Allreduce(send_receive_buffer, send_receive_buffer, count,
-                                                    GetNcclDataType(data_type), GetNcclRedOp(op),
-                                                    nccl_comm_, dh::DefaultStream()));
+    auto rc = stub_->Allreduce(send_receive_buffer, send_receive_buffer, count,
+                               GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
+                               dh::DefaultStream());
     CHECK(rc.OK()) << rc.Report();
   }
   allreduce_bytes_ += count * GetTypeSize(data_type);
@@ -194,8 +192,8 @@ void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_bu
   }
   dh::safe_cuda(cudaSetDevice(device_ordinal_));
-  auto rc = GetNCCLResult(stub_, stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8,
-                                                  nccl_comm_, dh::DefaultStream()));
+  auto rc = stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
+                             dh::DefaultStream());
   CHECK(rc.OK()) << rc.Report();
 }
 
@@ -216,19 +214,18 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
   receive_buffer->resize(total_bytes);
 
   size_t offset = 0;
-  auto rc = Success() << [&] { return GetNCCLResult(stub_, stub_->GroupStart()); } << [&] {
+  auto rc = Success() << [&] { return stub_->GroupStart(); } << [&] {
     for (int32_t i = 0; i < world_size_; ++i) {
       size_t as_bytes = segments->at(i);
-      auto rc = GetNCCLResult(
-          stub_, stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
-                                  ncclChar, i, nccl_comm_, dh::DefaultStream()));
+      auto rc = stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
+                                 ncclChar, i, nccl_comm_, dh::DefaultStream());
       if (!rc.OK()) {
         return rc;
       }
       offset += as_bytes;
     }
     return Success();
-  } << [&] { return GetNCCLResult(stub_, stub_->GroupEnd()); };
+  } << [&] { return stub_->GroupEnd(); };
 }
 
 void NcclDeviceCommunicator::Synchronize() {
diff --git a/src/collective/nccl_device_communicator.cuh b/src/collective/nccl_device_communicator.cuh
index a194b4ef2..ef431b571 100644
--- a/src/collective/nccl_device_communicator.cuh
+++ b/src/collective/nccl_device_communicator.cuh
@@ -66,7 +66,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
     static const int kRootRank = 0;
     ncclUniqueId id;
     if (rank_ == kRootRank) {
-      auto rc = GetNCCLResult(stub_, stub_->GetUniqueId(&id));
+      auto rc = stub_->GetUniqueId(&id);
       CHECK(rc.OK()) << rc.Report();
     }
     Broadcast(static_cast(&id), sizeof(ncclUniqueId), static_cast(kRootRank));
diff --git a/src/collective/nccl_stub.cc b/src/collective/nccl_stub.cc
index f4705a46e..fea3f2755 100644
--- a/src/collective/nccl_stub.cc
+++ b/src/collective/nccl_stub.cc
@@ -4,9 +4,12 @@
 #if defined(XGBOOST_USE_NCCL)
 #include "nccl_stub.h"
 
-#include <cuda.h>   // for CUDA_VERSION
-#include <dlfcn.h>  // for dlclose, dlsym, dlopen
+#include <cuda.h>                      // for CUDA_VERSION
+#include <cuda_runtime_api.h>          // for cudaPeekAtLastError
+#include <dlfcn.h>                     // for dlclose, dlsym, dlopen
+#include <nccl.h>
+#include <thrust/system/cuda/error.h>  // for cuda_category
+#include <thrust/system_error.h>       // for system_error
 
 #include <cstdint>  // for int32_t
 #include <sstream>  // for stringstream
@@ -16,6 +19,25 @@
 #include "xgboost/logging.h"
 
 namespace xgboost::collective {
+Result NcclStub::GetNcclResult(ncclResult_t code) const {
+  if (code == ncclSuccess) {
+    return Success();
+  }
+
+  std::stringstream ss;
+  ss << "NCCL failure: " << this->GetErrorString(code) << ".";
+  if (code == ncclUnhandledCudaError) {
+    // nccl usually preserves the last error so we can get more details.
+    auto err = cudaPeekAtLastError();
+    ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n";
+  } else if (code == ncclSystemError) {
+    ss << " This might be caused by a network configuration issue. Please consider specifying "
+          "the network interface for NCCL via environment variables listed in its reference: "
+          "`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
+  }
+  return Fail(ss.str());
+}
+
 NcclStub::NcclStub(StringView path) : path_{std::move(path)} {
 #if defined(XGBOOST_USE_DLOPEN_NCCL)
   CHECK(!path_.empty()) << "Empty path for NCCL.";
diff --git a/src/collective/nccl_stub.h b/src/collective/nccl_stub.h
index a003a6f22..5281b736d 100644
--- a/src/collective/nccl_stub.h
+++ b/src/collective/nccl_stub.h
@@ -8,9 +8,13 @@
 #include <string>  // for string
 
-#include "xgboost/string_view.h"  // for StringView
+#include "xgboost/collective/result.h"  // for Result
+#include "xgboost/string_view.h"        // for StringView
 
 namespace xgboost::collective {
+/**
+ * @brief A stub for NCCL to facilitate dynamic loading.
+ */
 class NcclStub {
 #if defined(XGBOOST_USE_DLOPEN_NCCL)
   void* handle_{nullptr};
@@ -30,61 +34,48 @@ class NcclStub {
   decltype(ncclGetErrorString)* get_error_string_{nullptr};
   decltype(ncclGetVersion)* get_version_{nullptr};
 
+ public:
+  Result GetNcclResult(ncclResult_t code) const;
+
  public:
   explicit NcclStub(StringView path);
   ~NcclStub();
-  [[nodiscard]] ncclResult_t Allreduce(const void* sendbuff, void* recvbuff, size_t count,
-                                       ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
-                                       cudaStream_t stream) const {
-    CHECK(allreduce_);
-    return this->allreduce_(sendbuff, recvbuff, count, datatype, op, comm, stream);
+  [[nodiscard]] Result Allreduce(const void* sendbuff, void* recvbuff, size_t count,
+                                 ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
+                                 cudaStream_t stream) const {
+    return this->GetNcclResult(allreduce_(sendbuff, recvbuff, count, datatype, op, comm, stream));
   }
-  [[nodiscard]] ncclResult_t Broadcast(const void* sendbuff, void* recvbuff, size_t count,
-                                       ncclDataType_t datatype, int root, ncclComm_t comm,
-                                       cudaStream_t stream) const {
-    CHECK(broadcast_);
-    return this->broadcast_(sendbuff, recvbuff, count, datatype, root, comm, stream);
+  [[nodiscard]] Result Broadcast(const void* sendbuff, void* recvbuff, size_t count,
+                                 ncclDataType_t datatype, int root, ncclComm_t comm,
+                                 cudaStream_t stream) const {
+    return this->GetNcclResult(broadcast_(sendbuff, recvbuff, count, datatype, root, comm, stream));
   }
-  [[nodiscard]] ncclResult_t Allgather(const void* sendbuff, void* recvbuff, size_t sendcount,
-                                       ncclDataType_t datatype, ncclComm_t comm,
-                                       cudaStream_t stream) const {
-    CHECK(allgather_);
-    return this->allgather_(sendbuff, recvbuff, sendcount, datatype, comm, stream);
+  [[nodiscard]] Result Allgather(const void* sendbuff, void* recvbuff, size_t sendcount,
+                                 ncclDataType_t datatype, ncclComm_t comm,
+                                 cudaStream_t stream) const {
+    return this->GetNcclResult(allgather_(sendbuff, recvbuff, sendcount, datatype, comm, stream));
  }
-  [[nodiscard]] ncclResult_t CommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId,
-                                          int rank) const {
-    CHECK(comm_init_rank_);
-    return this->comm_init_rank_(comm, nranks, commId, rank);
+  [[nodiscard]] Result CommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId,
+                                    int rank) const {
+    return this->GetNcclResult(this->comm_init_rank_(comm, nranks, commId, rank));
   }
-  [[nodiscard]] ncclResult_t CommDestroy(ncclComm_t comm) const {
-    CHECK(comm_destroy_);
-    return this->comm_destroy_(comm);
+  [[nodiscard]] Result CommDestroy(ncclComm_t comm) const {
+    return this->GetNcclResult(comm_destroy_(comm));
   }
-
-  [[nodiscard]] ncclResult_t GetUniqueId(ncclUniqueId* uniqueId) const {
-    CHECK(get_uniqueid_);
-    return this->get_uniqueid_(uniqueId);
+  [[nodiscard]] Result GetUniqueId(ncclUniqueId* uniqueId) const {
+    return this->GetNcclResult(get_uniqueid_(uniqueId));
   }
-  [[nodiscard]] ncclResult_t Send(const void* sendbuff, size_t count, ncclDataType_t datatype,
-                                  int peer, ncclComm_t comm, cudaStream_t stream) {
-    CHECK(send_);
-    return send_(sendbuff, count, datatype, peer, comm, stream);
+  [[nodiscard]] Result Send(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer,
+                            ncclComm_t comm, cudaStream_t stream) {
+    return this->GetNcclResult(send_(sendbuff, count, datatype, peer, comm, stream));
   }
-  [[nodiscard]] ncclResult_t Recv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer,
-                                  ncclComm_t comm, cudaStream_t stream) const {
-    CHECK(recv_);
-    return recv_(recvbuff, count, datatype, peer, comm, stream);
+  [[nodiscard]] Result Recv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer,
+                            ncclComm_t comm, cudaStream_t stream) const {
+    return this->GetNcclResult(recv_(recvbuff, count, datatype, peer, comm, stream));
   }
-  [[nodiscard]] ncclResult_t GroupStart() const {
-    CHECK(group_start_);
-    return group_start_();
-  }
-  [[nodiscard]] ncclResult_t GroupEnd() const {
-    CHECK(group_end_);
-    return group_end_();
-  }
-
+  [[nodiscard]] Result GroupStart() const { return this->GetNcclResult(group_start_()); }
+  [[nodiscard]] Result GroupEnd() const { return this->GetNcclResult(group_end_()); }
   [[nodiscard]] const char* GetErrorString(ncclResult_t result) const {
     return get_error_string_(result);
   }
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 89ec42f2b..ffe61800e 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -36,10 +36,6 @@
 #include "xgboost/logging.h"
 #include "xgboost/span.h"
 
-#ifdef XGBOOST_USE_NCCL
-#include "nccl.h"
-#endif  // XGBOOST_USE_NCCL
-
 #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
 #include "rmm/mr/device/per_device_resource.hpp"
 #include "rmm/mr/device/thrust_allocator_adaptor.hpp"
diff --git a/tests/cpp/collective/test_comm.cc b/tests/cpp/collective/test_comm.cc
index 52fec7b5d..8e69b2f8e 100644
--- a/tests/cpp/collective/test_comm.cc
+++ b/tests/cpp/collective/test_comm.cc
@@ -25,15 +25,18 @@ TEST_F(CommTest, Channel) {
     WorkerForTest worker{host, port, timeout, n_workers, i};
     if (i % 2 == 0) {
       auto p_chan = worker.Comm().Chan(i + 1);
-      p_chan->SendAll(
-          EraseType(common::Span{&i, static_cast(1)}));
-      auto rc = p_chan->Block();
+      auto rc = Success() << [&] {
+        return p_chan->SendAll(
+            EraseType(common::Span{&i, static_cast(1)}));
+      } << [&] { return p_chan->Block(); };
       ASSERT_TRUE(rc.OK()) << rc.Report();
     } else {
      auto p_chan = worker.Comm().Chan(i - 1);
      std::int32_t r{-1};
-      p_chan->RecvAll(EraseType(common::Span{&r, static_cast(1)}));
-      auto rc = p_chan->Block();
+      auto rc = Success() << [&] {
+        return p_chan->RecvAll(
+            EraseType(common::Span{&r, static_cast(1)}));
+      } << [&] { return p_chan->Block(); };
       ASSERT_TRUE(rc.OK()) << rc.Report();
       ASSERT_EQ(r, i - 1);
     }
diff --git a/tests/cpp/collective/test_nccl_device_communicator.cu b/tests/cpp/collective/test_nccl_device_communicator.cu
index 3d7b1efc8..47e86220d 100644
--- a/tests/cpp/collective/test_nccl_device_communicator.cu
+++ b/tests/cpp/collective/test_nccl_device_communicator.cu
@@ -23,7 +23,7 @@ TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
 
 TEST(NcclDeviceCommunicatorSimpleTest, SystemError) {
   auto stub = std::make_shared(DefaultNcclName());
-  auto rc = GetNCCLResult(stub, ncclSystemError);
+  auto rc = stub->GetNcclResult(ncclSystemError);
   auto msg = rc.Report();
   ASSERT_TRUE(msg.find("environment variables") != std::string::npos);
 }
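Note on the pattern used throughout this patch: calls that previously returned raw ncclResult_t (or returned void and CHECK-ed internally) now return Result, and successive steps are composed with operator<<, which runs each step only if everything before it succeeded. The snippet below is a minimal, self-contained sketch of that short-circuiting composition style; the Result type here is a simplified stand-in written for illustration only, not the actual definition in xgboost/collective/result.h.

// Illustrative sketch only: a simplified stand-in for the Result type used in
// the patch above; the real type in xgboost/collective/result.h is richer.
#include <iostream>
#include <string>
#include <utility>

struct Result {
  bool ok{true};
  std::string message;
  [[nodiscard]] bool OK() const { return ok; }
  [[nodiscard]] std::string Report() const { return message; }
};

inline Result Success() { return Result{}; }
inline Result Fail(std::string msg) { return Result{false, std::move(msg)}; }

// Run the next step only if everything so far succeeded; otherwise propagate
// the first failure unchanged. This mirrors the `Success() << [&] { ... }`
// chains introduced in allgather.cc, broadcast.cc, and coll.cu.
template <typename Fn>
Result operator<<(Result&& prev, Fn&& fn) {
  if (!prev.OK()) {
    return std::move(prev);
  }
  return std::forward<Fn>(fn)();
}

int main() {
  auto rc = Success() << [] { return Success(); }             // e.g. a send step
                      << [] { return Fail("recv failed."); }  // e.g. a recv step
                      << [] { return Success(); };            // skipped: prior step failed
  if (!rc.OK()) {
    std::cout << rc.Report() << "\n";
  }
  return rc.OK() ? 0 : 1;
}

The payoff is visible in RingAllgather and Broadcast above: send, receive, and block collapse into one chained expression with a single rc.OK() check, instead of unchecked calls followed by a lone Block() result or per-call CHECKs.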