Cleanup code for distributed training. (#9805)
* Cleanup code for distributed training.

- Merge `GetNcclResult` into the NCCL stub.
- Split up utilities from the main Dask module.
- Let `Channel` return `Result` to accommodate the NCCL channel.
- Remove the old `use_label_encoder` parameter.
This commit is contained in:
parent e9260de3f3
commit 8fe1a2213c
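Most of the hunks below follow one pattern: calls that used to fire and forget (or return a raw `ncclResult_t`) now return `Result`, and successive steps are composed with `operator<<` so the first failure short-circuits the rest. The snippet below is a minimal, self-contained sketch of that chaining idea; the `Result`, `Success`, `Fail`, and `operator<<` defined here are simplified stand-ins for illustration, not the real `xgboost::collective` implementations.

// Minimal sketch of chaining fallible steps with operator<<, assuming a
// simplified Result type (not the real xgboost::collective::Result).
#include <iostream>
#include <string>
#include <utility>

struct Result {
  bool ok{true};
  std::string message;
  bool OK() const { return ok; }
  std::string Report() const { return message; }
};

inline Result Success() { return Result{}; }
inline Result Fail(std::string msg) { return Result{false, std::move(msg)}; }

// Run the next step only if everything so far succeeded; otherwise keep the
// first error and skip the remaining steps.
template <typename Fn>
Result operator<<(Result&& prev, Fn&& next) {
  if (!prev.OK()) {
    return std::move(prev);
  }
  return next();
}

int main() {
  auto rc = Success() << [&] { return Success(); }            // e.g. SendAll
                      << [&] { return Fail("recv failed"); }  // e.g. RecvAll: fails here
                      << [&] { return Success(); };           // e.g. Block: skipped
  std::cout << (rc.OK() ? std::string{"ok"} : rc.Report()) << "\n";
  return rc.OK() ? 0 : 1;
}

This is why hunks such as `RingAllgather` below can replace an explicit check after every call with a single chained expression whose final `Result` is inspected once.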
@@ -94,6 +94,8 @@ from xgboost.sklearn import (
 from xgboost.tracker import RabitTracker, get_host_ip
 from xgboost.training import train as worker_train

+from .utils import get_n_threads
+
 if TYPE_CHECKING:
     import dask
     import distributed
@@ -908,6 +910,34 @@ async def _check_workers_are_alive(
         raise RuntimeError(f"Missing required workers: {missing_workers}")


+def _get_dmatrices(
+    train_ref: dict,
+    train_id: int,
+    *refs: dict,
+    evals_id: Sequence[int],
+    evals_name: Sequence[str],
+    n_threads: int,
+) -> Tuple[DMatrix, List[Tuple[DMatrix, str]]]:
+    Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
+    evals: List[Tuple[DMatrix, str]] = []
+    for i, ref in enumerate(refs):
+        if evals_id[i] == train_id:
+            evals.append((Xy, evals_name[i]))
+            continue
+        if ref.get("ref", None) is not None:
+            if ref["ref"] != train_id:
+                raise ValueError(
+                    "The training DMatrix should be used as a reference to evaluation"
+                    " `QuantileDMatrix`."
+                )
+            del ref["ref"]
+            eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads, ref=Xy)
+        else:
+            eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads)
+        evals.append((eval_Xy, evals_name[i]))
+    return Xy, evals
+
+
 async def _train_async(
     client: "distributed.Client",
     global_config: Dict[str, Any],
@@ -940,41 +970,20 @@ async def _train_async(
     ) -> Optional[TrainReturnT]:
         worker = distributed.get_worker()
         local_param = parameters.copy()
-        n_threads = 0
-        # dask worker nthreads, "state" is available in 2022.6.1
-        dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
-        for p in ["nthread", "n_jobs"]:
-            if (
-                local_param.get(p, None) is not None
-                and local_param.get(p, dwnt) != dwnt
-            ):
-                LOGGER.info("Overriding `nthreads` defined in dask worker.")
-                n_threads = local_param[p]
-                break
-        if n_threads == 0 or n_threads is None:
-            n_threads = dwnt
+        n_threads = get_n_threads(local_param, worker)
         local_param.update({"nthread": n_threads, "n_jobs": n_threads})

         local_history: TrainingCallback.EvalsLog = {}

         with CommunicatorContext(**rabit_args), config.config_context(**global_config):
-            Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
-            evals: List[Tuple[DMatrix, str]] = []
-            for i, ref in enumerate(refs):
-                if evals_id[i] == train_id:
-                    evals.append((Xy, evals_name[i]))
-                    continue
-                if ref.get("ref", None) is not None:
-                    if ref["ref"] != train_id:
-                        raise ValueError(
-                            "The training DMatrix should be used as a reference"
-                            " to evaluation `QuantileDMatrix`."
-                        )
-                    del ref["ref"]
-                    eval_Xy = _dmatrix_from_list_of_parts(
-                        **ref, nthread=n_threads, ref=Xy
-                    )
-                else:
-                    eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads)
-                evals.append((eval_Xy, evals_name[i]))
+            Xy, evals = _get_dmatrices(
+                train_ref,
+                train_id,
+                *refs,
+                evals_id=evals_id,
+                evals_name=evals_name,
+                n_threads=n_threads,
+            )

             booster = worker_train(
                 params=local_param,
python-package/xgboost/dask/utils.py (new file, 24 additions)
@@ -0,0 +1,24 @@
+"""Utilities for the XGBoost Dask interface."""
+import logging
+from typing import TYPE_CHECKING, Any, Dict
+
+LOGGER = logging.getLogger("[xgboost.dask]")
+
+
+if TYPE_CHECKING:
+    import distributed
+
+
+def get_n_threads(local_param: Dict[str, Any], worker: "distributed.Worker") -> int:
+    """Get the number of threads from a worker and the user-supplied parameters."""
+    # dask worker nthreads, "state" is available in 2022.6.1
+    dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
+    n_threads = None
+    for p in ["nthread", "n_jobs"]:
+        if local_param.get(p, None) is not None and local_param.get(p, dwnt) != dwnt:
+            LOGGER.info("Overriding `nthreads` defined in dask worker.")
+            n_threads = local_param[p]
+            break
+    if n_threads == 0 or n_threads is None:
+        n_threads = dwnt
+    return n_threads
@@ -808,7 +808,6 @@ class XGBModel(XGBModelBase):
             "kwargs",
             "missing",
             "n_estimators",
-            "use_label_encoder",
             "enable_categorical",
             "early_stopping_rounds",
             "callbacks",
@@ -138,7 +138,6 @@ _inverse_pyspark_param_alias_map = {v: k for k, v in _pyspark_param_alias_map.it
 _unsupported_xgb_params = [
     "gpu_id",  # we have "device" pyspark param instead.
     "enable_categorical",  # Use feature_types param to specify categorical feature instead
-    "use_label_encoder",
     "n_jobs",  # Do not allow user to set it, will use `spark.task.cpus` value instead.
     "nthread",  # Ditto
 ]
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2023 by XGBoost Contributors
+ * Copyright 2019-2023, XGBoost Contributors
  */
 #include <thrust/transform.h>  // for transform

@@ -15,6 +15,9 @@
 #include "xgboost/data.h"
 #include "xgboost/json.h"
 #include "xgboost/learner.h"
+#if defined(XGBOOST_USE_NCCL)
+#include <nccl.h>
+#endif

 namespace xgboost {
 void XGBBuildInfoDevice(Json *p_info) {
@@ -26,18 +26,19 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
   }

   for (std::int32_t r = 0; r < world; ++r) {
-    auto send_rank = (rank + world - r + worker_off) % world;
-    auto send_off = send_rank * segment_size;
-    send_off = std::min(send_off, data.size_bytes());
-    auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off));
-    next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
-    auto recv_rank = (rank + world - r - 1 + worker_off) % world;
-    auto recv_off = recv_rank * segment_size;
-    recv_off = std::min(recv_off, data.size_bytes());
-    auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off));
-    prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
-    auto rc = prev_ch->Block();
+    auto rc = Success() << [&] {
+      auto send_rank = (rank + world - r + worker_off) % world;
+      auto send_off = send_rank * segment_size;
+      send_off = std::min(send_off, data.size_bytes());
+      auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off));
+      return next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
+    } << [&] {
+      auto recv_rank = (rank + world - r - 1 + worker_off) % world;
+      auto recv_off = recv_rank * segment_size;
+      recv_off = std::min(recv_off, data.size_bytes());
+      auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off));
+      return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
+    } << [&] { return prev_ch->Block(); };
     if (!rc.OK()) {
       return rc;
     }
@@ -78,19 +79,19 @@ namespace detail {
   auto next_ch = comm.Chan(next);

   for (std::int32_t r = 0; r < world; ++r) {
-    auto send_rank = (rank + world - r) % world;
-    auto send_off = offset[send_rank];
-    auto send_size = sizes[send_rank];
-    auto send_seg = erased_result.subspan(send_off, send_size);
-    next_ch->SendAll(send_seg);
-    auto recv_rank = (rank + world - r - 1) % world;
-    auto recv_off = offset[recv_rank];
-    auto recv_size = sizes[recv_rank];
-    auto recv_seg = erased_result.subspan(recv_off, recv_size);
-    prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
-    auto rc = prev_ch->Block();
+    auto rc = Success() << [&] {
+      auto send_rank = (rank + world - r) % world;
+      auto send_off = offset[send_rank];
+      auto send_size = sizes[send_rank];
+      auto send_seg = erased_result.subspan(send_off, send_size);
+      return next_ch->SendAll(send_seg);
+    } << [&] {
+      auto recv_rank = (rank + world - r - 1) % world;
+      auto recv_off = offset[recv_rank];
+      auto recv_size = sizes[recv_rank];
+      auto recv_seg = erased_result.subspan(recv_off, recv_size);
+      return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
+    } << [&] { return prev_ch->Block(); };
     if (!rc.OK()) {
       return rc;
     }
@@ -37,7 +37,10 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
     auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg);
     auto send_seg = data.subspan(send_off, seg_nbytes);

-    next_ch->SendAll(send_seg);
+    auto rc = next_ch->SendAll(send_seg);
+    if (!rc.OK()) {
+      return rc;
+    }

     // receive from ring prev
     auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg;
@@ -47,8 +50,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
     auto recv_seg = data.subspan(recv_off, seg_nbytes);
     auto seg = s_buf.subspan(0, recv_seg.size());

-    prev_ch->RecvAll(seg);
-    auto rc = comm.Block();
+    rc = std::move(rc) << [&] { return prev_ch->RecvAll(seg); } << [&] { return comm.Block(); };
     if (!rc.OK()) {
       return rc;
     }
@@ -62,8 +62,8 @@ Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t

   if (shifted_rank != 0) {  // not root
     auto parent = ShiftRight(ShiftedParentRank(shifted_rank, depth), world, root);
-    comm.Chan(parent)->RecvAll(data);
-    auto rc = comm.Chan(parent)->Block();
+    auto rc = Success() << [&] { return comm.Chan(parent)->RecvAll(data); }
+                        << [&] { return comm.Chan(parent)->Block(); };
     if (!rc.OK()) {
       return Fail("broadcast failed.", std::move(rc));
     }
@@ -75,7 +75,10 @@ Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t
       auto sft_peer = shifted_rank + (1 << i);
       auto peer = ShiftRight(sft_peer, world, root);
       CHECK_NE(peer, root);
-      comm.Chan(peer)->SendAll(data);
+      auto rc = comm.Chan(peer)->SendAll(data);
+      if (!rc.OK()) {
+        return rc;
+      }
     }
   }

@@ -79,8 +79,8 @@ void RunBitwiseAllreduce(dh::CUDAStreamView stream, common::Span<std::int8_t> ou

   // First gather data from all the workers.
   CHECK(handle);
-  auto rc = GetNCCLResult(stub, stub->Allgather(data.data(), device_buffer, data.size(), ncclInt8,
-                                                handle, pcomm->Stream()));
+  auto rc =
+      stub->Allgather(data.data(), device_buffer, data.size(), ncclInt8, handle, pcomm->Stream());
   if (!rc.OK()) {
     return rc;
   }
@@ -140,9 +140,8 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
     return DispatchDType(type, [=](auto t) {
       using T = decltype(t);
       auto rdata = common::RestoreType<T>(data);
-      auto rc = stub->Allreduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
-                                GetNCCLRedOp(op), nccl->Handle(), nccl->Stream());
-      return GetNCCLResult(stub, rc);
+      return stub->Allreduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
+                             GetNCCLRedOp(op), nccl->Handle(), nccl->Stream());
     });
   }
 } << [&] { return nccl->Block(); };
@@ -158,8 +157,8 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
   auto stub = nccl->Stub();

   return Success() << [&] {
-    return GetNCCLResult(stub, stub->Broadcast(data.data(), data.data(), data.size_bytes(),
-                                               ncclInt8, root, nccl->Handle(), nccl->Stream()));
+    return stub->Broadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root,
+                           nccl->Handle(), nccl->Stream());
   } << [&] { return nccl->Block(); };
 }

@@ -174,8 +173,8 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {

   auto send = data.subspan(comm.Rank() * size, size);
   return Success() << [&] {
-    return GetNCCLResult(stub, stub->Allgather(send.data(), data.data(), size, ncclInt8,
-                                               nccl->Handle(), nccl->Stream()));
+    return stub->Allgather(send.data(), data.data(), size, ncclInt8, nccl->Handle(),
+                           nccl->Stream());
   } << [&] { return nccl->Block(); };
 }

@@ -188,19 +187,19 @@ namespace cuda_impl {
 Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const> data,
                            common::Span<std::int64_t const> sizes, common::Span<std::int8_t> recv) {
   auto stub = comm->Stub();
-  return Success() << [&stub] { return GetNCCLResult(stub, stub->GroupStart()); } << [&] {
+  return Success() << [&stub] { return stub->GroupStart(); } << [&] {
     std::size_t offset = 0;
     for (std::int32_t r = 0; r < comm->World(); ++r) {
       auto as_bytes = sizes[r];
       auto rc = stub->Broadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
                                 ncclInt8, r, comm->Handle(), dh::DefaultStream());
-      if (rc != ncclSuccess) {
-        return GetNCCLResult(stub, rc);
+      if (!rc.OK()) {
+        return rc;
       }
       offset += as_bytes;
     }
     return Success();
-  } << [&] { return GetNCCLResult(stub, stub->GroupEnd()); };
+  } << [&] { return stub->GroupEnd(); };
 }
 }  // namespace cuda_impl

@@ -217,7 +216,7 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const>

   switch (algo) {
     case AllgatherVAlgo::kRing: {
-      return Success() << [&] { return GetNCCLResult(stub, stub->GroupStart()); } << [&] {
+      return Success() << [&] { return stub->GroupStart(); } << [&] {
         // get worker offset
         detail::AllgatherVOffset(sizes, recv_segments);
         // copy data
@@ -228,7 +227,7 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const>
         }
         return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
       } << [&] {
-        return GetNCCLResult(stub, stub->GroupEnd());
+        return stub->GroupEnd();
       } << [&] { return nccl->Block(); };
     }
     case AllgatherVAlgo::kBcast: {
@@ -26,7 +26,7 @@ Result GetUniqueId(Comm const& comm, std::shared_ptr<NcclStub> stub, std::shared
   static const int kRootRank = 0;
   ncclUniqueId id;
   if (comm.Rank() == kRootRank) {
-    auto rc = GetNCCLResult(stub, stub->GetUniqueId(&id));
+    auto rc = stub->GetUniqueId(&id);
     CHECK(rc.OK()) << rc.Report();
   }
   auto rc = coll->Broadcast(
@@ -99,12 +99,10 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
         << "Multiple processes within communication group running on same CUDA "
         << "device is not supported. " << PrintUUID(s_this_uuid) << "\n";

-  rc = std::move(rc) << [&] {
-    return GetUniqueId(root, this->stub_, pimpl, &nccl_unique_id_);
-  } << [&] {
-    return GetNCCLResult(this->stub_, this->stub_->CommInitRank(&nccl_comm_, root.World(),
-                                                                nccl_unique_id_, root.Rank()));
-  };
+  rc = std::move(rc) << [&] { return GetUniqueId(root, this->stub_, pimpl, &nccl_unique_id_); } <<
+      [&] {
+        return this->stub_->CommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank());
+      };
   CHECK(rc.OK()) << rc.Report();

   for (std::int32_t r = 0; r < root.World(); ++r) {
@@ -115,7 +113,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p

 NCCLComm::~NCCLComm() {
   if (nccl_comm_) {
-    auto rc = GetNCCLResult(stub_, stub_->CommDestroy(nccl_comm_));
+    auto rc = stub_->CommDestroy(nccl_comm_);
     CHECK(rc.OK()) << rc.Report();
   }
 }
@@ -52,25 +52,6 @@ class NCCLComm : public Comm {
   }
 };

-inline Result GetNCCLResult(std::shared_ptr<NcclStub> stub, ncclResult_t code) {
-  if (code == ncclSuccess) {
-    return Success();
-  }
-
-  std::stringstream ss;
-  ss << "NCCL failure: " << stub->GetErrorString(code) << ".";
-  if (code == ncclUnhandledCudaError) {
-    // nccl usually preserves the last error so we can get more details.
-    auto err = cudaPeekAtLastError();
-    ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n";
-  } else if (code == ncclSystemError) {
-    ss << " This might be caused by a network configuration issue. Please consider specifying "
-          "the network interface for NCCL via environment variables listed in its reference: "
-          "`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
-  }
-  return Fail(ss.str());
-}
-
 class NCCLChannel : public Channel {
   std::int32_t rank_{-1};
   ncclComm_t nccl_comm_{};
@@ -86,13 +67,11 @@ class NCCLChannel : public Channel {
         Channel{comm, nullptr},
         stream_{stream} {}

-  void SendAll(std::int8_t const* ptr, std::size_t n) override {
-    auto rc = GetNCCLResult(stub_, stub_->Send(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
-    CHECK(rc.OK()) << rc.Report();
+  [[nodiscard]] Result SendAll(std::int8_t const* ptr, std::size_t n) override {
+    return stub_->Send(ptr, n, ncclInt8, rank_, nccl_comm_, stream_);
   }
-  void RecvAll(std::int8_t* ptr, std::size_t n) override {
-    auto rc = GetNCCLResult(stub_, stub_->Recv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
-    CHECK(rc.OK()) << rc.Report();
+  [[nodiscard]] Result RecvAll(std::int8_t* ptr, std::size_t n) override {
+    return stub_->Recv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_);
   }
   [[nodiscard]] Result Block() override {
     auto rc = stream_.Sync(false);
@@ -135,21 +135,25 @@ class Channel {
   explicit Channel(Comm const& comm, std::shared_ptr<TCPSocket> sock)
       : sock_{std::move(sock)}, comm_{comm} {}

-  virtual void SendAll(std::int8_t const* ptr, std::size_t n) {
+  [[nodiscard]] virtual Result SendAll(std::int8_t const* ptr, std::size_t n) {
     Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast<std::int8_t*>(ptr), n, sock_.get(), 0};
     CHECK(sock_.get());
     comm_.Submit(std::move(op));
+    return Success();
   }
-  void SendAll(common::Span<std::int8_t const> data) {
-    this->SendAll(data.data(), data.size_bytes());
+  [[nodiscard]] Result SendAll(common::Span<std::int8_t const> data) {
+    return this->SendAll(data.data(), data.size_bytes());
   }

-  virtual void RecvAll(std::int8_t* ptr, std::size_t n) {
+  [[nodiscard]] virtual Result RecvAll(std::int8_t* ptr, std::size_t n) {
     Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0};
     CHECK(sock_.get());
     comm_.Submit(std::move(op));
+    return Success();
+  }
+  [[nodiscard]] Result RecvAll(common::Span<std::int8_t> data) {
+    return this->RecvAll(data.data(), data.size_bytes());
   }
-  void RecvAll(common::Span<std::int8_t> data) { this->RecvAll(data.data(), data.size_bytes()); }

   [[nodiscard]] auto Socket() const { return sock_; }
   [[nodiscard]] virtual Result Block() { return comm_.Block(); }
@@ -46,8 +46,7 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sy

   nccl_unique_id_ = GetUniqueId();
   dh::safe_cuda(cudaSetDevice(device_ordinal_));
-  auto rc =
-      GetNCCLResult(stub_, stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
+  auto rc = stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_);
   CHECK(rc.OK()) << rc.Report();
 }

@@ -56,7 +55,7 @@ NcclDeviceCommunicator::~NcclDeviceCommunicator() {
     return;
   }
   if (nccl_comm_) {
-    auto rc = GetNCCLResult(stub_, stub_->CommDestroy(nccl_comm_));
+    auto rc = stub_->CommDestroy(nccl_comm_);
     CHECK(rc.OK()) << rc.Report();
   }
   if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
@@ -143,9 +142,8 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
   auto *device_buffer = buffer.data().get();

   // First gather data from all the workers.
-  auto rc = GetNCCLResult(
-      stub_, stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
-                              nccl_comm_, dh::DefaultStream()));
+  auto rc = stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
+                             nccl_comm_, dh::DefaultStream());
   CHECK(rc.OK()) << rc.Report();
   if (needs_sync_) {
     dh::DefaultStream().Sync();
@@ -178,9 +176,9 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
   if (IsBitwiseOp(op)) {
     BitwiseAllReduce(send_receive_buffer, count, data_type, op);
   } else {
-    auto rc = GetNCCLResult(stub_, stub_->Allreduce(send_receive_buffer, send_receive_buffer, count,
-                                                    GetNcclDataType(data_type), GetNcclRedOp(op),
-                                                    nccl_comm_, dh::DefaultStream()));
+    auto rc = stub_->Allreduce(send_receive_buffer, send_receive_buffer, count,
+                               GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
+                               dh::DefaultStream());
     CHECK(rc.OK()) << rc.Report();
   }
   allreduce_bytes_ += count * GetTypeSize(data_type);
@@ -194,8 +192,8 @@ void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_bu
   }

   dh::safe_cuda(cudaSetDevice(device_ordinal_));
-  auto rc = GetNCCLResult(stub_, stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8,
-                                                  nccl_comm_, dh::DefaultStream()));
+  auto rc = stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
+                             dh::DefaultStream());
   CHECK(rc.OK()) << rc.Report();
 }

@@ -216,19 +214,18 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
   receive_buffer->resize(total_bytes);

   size_t offset = 0;
-  auto rc = Success() << [&] { return GetNCCLResult(stub_, stub_->GroupStart()); } << [&] {
+  auto rc = Success() << [&] { return stub_->GroupStart(); } << [&] {
     for (int32_t i = 0; i < world_size_; ++i) {
       size_t as_bytes = segments->at(i);
-      auto rc = GetNCCLResult(
-          stub_, stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
-                                  ncclChar, i, nccl_comm_, dh::DefaultStream()));
+      auto rc = stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
+                                 ncclChar, i, nccl_comm_, dh::DefaultStream());
       if (!rc.OK()) {
         return rc;
       }
       offset += as_bytes;
     }
     return Success();
-  } << [&] { return GetNCCLResult(stub_, stub_->GroupEnd()); };
+  } << [&] { return stub_->GroupEnd(); };
 }

 void NcclDeviceCommunicator::Synchronize() {
@@ -66,7 +66,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
     static const int kRootRank = 0;
     ncclUniqueId id;
     if (rank_ == kRootRank) {
-      auto rc = GetNCCLResult(stub_, stub_->GetUniqueId(&id));
+      auto rc = stub_->GetUniqueId(&id);
       CHECK(rc.OK()) << rc.Report();
     }
     Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
@@ -4,9 +4,12 @@
 #if defined(XGBOOST_USE_NCCL)
 #include "nccl_stub.h"

 #include <cuda.h>              // for CUDA_VERSION
-#include <dlfcn.h>             // for dlclose, dlsym, dlopen
+#include <cuda_runtime_api.h>  // for cudaPeekAtLastError
+#include <dlfcn.h>             // for dlclose, dlsym, dlopen
 #include <nccl.h>
+#include <thrust/system/cuda/error.h>  // for cuda_category
+#include <thrust/system_error.h>       // for system_error

 #include <cstdint>  // for int32_t
 #include <sstream>  // for stringstream
@@ -16,6 +19,25 @@
 #include "xgboost/logging.h"

 namespace xgboost::collective {
+Result NcclStub::GetNcclResult(ncclResult_t code) const {
+  if (code == ncclSuccess) {
+    return Success();
+  }
+
+  std::stringstream ss;
+  ss << "NCCL failure: " << this->GetErrorString(code) << ".";
+  if (code == ncclUnhandledCudaError) {
+    // nccl usually preserves the last error so we can get more details.
+    auto err = cudaPeekAtLastError();
+    ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n";
+  } else if (code == ncclSystemError) {
+    ss << " This might be caused by a network configuration issue. Please consider specifying "
+          "the network interface for NCCL via environment variables listed in its reference: "
+          "`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
+  }
+  return Fail(ss.str());
+}
+
 NcclStub::NcclStub(StringView path) : path_{std::move(path)} {
 #if defined(XGBOOST_USE_DLOPEN_NCCL)
   CHECK(!path_.empty()) << "Empty path for NCCL.";
@@ -8,9 +8,13 @@

 #include <string>  // for string

-#include "xgboost/string_view.h"  // for StringView
+#include "xgboost/collective/result.h"  // for Result
+#include "xgboost/string_view.h"        // for StringView

 namespace xgboost::collective {
+/**
+ * @brief A stub for NCCL to facilitate dynamic loading.
+ */
 class NcclStub {
 #if defined(XGBOOST_USE_DLOPEN_NCCL)
   void* handle_{nullptr};
@@ -30,61 +34,48 @@ class NcclStub {
   decltype(ncclGetErrorString)* get_error_string_{nullptr};
   decltype(ncclGetVersion)* get_version_{nullptr};

+ public:
+  Result GetNcclResult(ncclResult_t code) const;
+
  public:
   explicit NcclStub(StringView path);
   ~NcclStub();

-  [[nodiscard]] ncclResult_t Allreduce(const void* sendbuff, void* recvbuff, size_t count,
+  [[nodiscard]] Result Allreduce(const void* sendbuff, void* recvbuff, size_t count,
                                  ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
                                  cudaStream_t stream) const {
-    CHECK(allreduce_);
-    return this->allreduce_(sendbuff, recvbuff, count, datatype, op, comm, stream);
+    return this->GetNcclResult(allreduce_(sendbuff, recvbuff, count, datatype, op, comm, stream));
   }
-  [[nodiscard]] ncclResult_t Broadcast(const void* sendbuff, void* recvbuff, size_t count,
+  [[nodiscard]] Result Broadcast(const void* sendbuff, void* recvbuff, size_t count,
                                  ncclDataType_t datatype, int root, ncclComm_t comm,
                                  cudaStream_t stream) const {
-    CHECK(broadcast_);
-    return this->broadcast_(sendbuff, recvbuff, count, datatype, root, comm, stream);
+    return this->GetNcclResult(broadcast_(sendbuff, recvbuff, count, datatype, root, comm, stream));
   }
-  [[nodiscard]] ncclResult_t Allgather(const void* sendbuff, void* recvbuff, size_t sendcount,
+  [[nodiscard]] Result Allgather(const void* sendbuff, void* recvbuff, size_t sendcount,
                                  ncclDataType_t datatype, ncclComm_t comm,
                                  cudaStream_t stream) const {
-    CHECK(allgather_);
-    return this->allgather_(sendbuff, recvbuff, sendcount, datatype, comm, stream);
+    return this->GetNcclResult(allgather_(sendbuff, recvbuff, sendcount, datatype, comm, stream));
   }
-  [[nodiscard]] ncclResult_t CommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId,
+  [[nodiscard]] Result CommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId,
                                     int rank) const {
-    CHECK(comm_init_rank_);
-    return this->comm_init_rank_(comm, nranks, commId, rank);
+    return this->GetNcclResult(this->comm_init_rank_(comm, nranks, commId, rank));
   }
-  [[nodiscard]] ncclResult_t CommDestroy(ncclComm_t comm) const {
-    CHECK(comm_destroy_);
-    return this->comm_destroy_(comm);
+  [[nodiscard]] Result CommDestroy(ncclComm_t comm) const {
+    return this->GetNcclResult(comm_destroy_(comm));
   }
-  [[nodiscard]] ncclResult_t GetUniqueId(ncclUniqueId* uniqueId) const {
-    CHECK(get_uniqueid_);
-    return this->get_uniqueid_(uniqueId);
+  [[nodiscard]] Result GetUniqueId(ncclUniqueId* uniqueId) const {
+    return this->GetNcclResult(get_uniqueid_(uniqueId));
   }
-  [[nodiscard]] ncclResult_t Send(const void* sendbuff, size_t count, ncclDataType_t datatype,
-                                  int peer, ncclComm_t comm, cudaStream_t stream) {
-    CHECK(send_);
-    return send_(sendbuff, count, datatype, peer, comm, stream);
+  [[nodiscard]] Result Send(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer,
+                            ncclComm_t comm, cudaStream_t stream) {
+    return this->GetNcclResult(send_(sendbuff, count, datatype, peer, comm, stream));
   }
-  [[nodiscard]] ncclResult_t Recv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer,
+  [[nodiscard]] Result Recv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer,
                             ncclComm_t comm, cudaStream_t stream) const {
-    CHECK(recv_);
-    return recv_(recvbuff, count, datatype, peer, comm, stream);
+    return this->GetNcclResult(recv_(recvbuff, count, datatype, peer, comm, stream));
   }
-  [[nodiscard]] ncclResult_t GroupStart() const {
-    CHECK(group_start_);
-    return group_start_();
-  }
-  [[nodiscard]] ncclResult_t GroupEnd() const {
-    CHECK(group_end_);
-    return group_end_();
-  }
+  [[nodiscard]] Result GroupStart() const { return this->GetNcclResult(group_start_()); }
+  [[nodiscard]] Result GroupEnd() const { return this->GetNcclResult(group_end_()); }

   [[nodiscard]] const char* GetErrorString(ncclResult_t result) const {
     return get_error_string_(result);
   }
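For reference, the `NcclStub` change above follows a common boundary pattern: the raw C API reports errors through status codes, and the stub converts each code into a `Result` at the call site so no caller ever handles raw codes. Below is a minimal standalone sketch of that idea; `FakeStatus`, `FakeAllreduce`, and `FakeStub` are hypothetical stand-ins, not the real NCCL API, and only the shape mirrors `NcclStub::GetNcclResult`.

// Sketch only: the names here are made up; only the wrapping pattern is real.
#include <iostream>
#include <string>

enum FakeStatus { kOk = 0, kSystemError = 2 };  // stand-in for a C status enum

// Stand-in for a C entry point that reports errors via status codes.
inline FakeStatus FakeAllreduce(bool healthy) { return healthy ? kOk : kSystemError; }

struct Result {
  bool ok{true};
  std::string message;
  bool OK() const { return ok; }
  std::string Report() const { return message; }
};

class FakeStub {
 public:
  // Single place that translates status codes into Result.
  Result GetResult(FakeStatus code) const {
    if (code == kOk) {
      return Result{};
    }
    return Result{false, "API failure: status " + std::to_string(code)};
  }
  // Every wrapper converts immediately, so callers only ever see Result.
  Result Allreduce(bool healthy) const { return this->GetResult(FakeAllreduce(healthy)); }
};

int main() {
  FakeStub stub;
  auto rc = stub.Allreduce(/*healthy=*/false);
  if (!rc.OK()) {
    std::cerr << rc.Report() << "\n";
  }
  return 0;
}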
@@ -36,10 +36,6 @@
 #include "xgboost/logging.h"
 #include "xgboost/span.h"

-#ifdef XGBOOST_USE_NCCL
-#include "nccl.h"
-#endif  // XGBOOST_USE_NCCL
-
 #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
 #include "rmm/mr/device/per_device_resource.hpp"
 #include "rmm/mr/device/thrust_allocator_adaptor.hpp"
@@ -25,15 +25,18 @@ TEST_F(CommTest, Channel) {
     WorkerForTest worker{host, port, timeout, n_workers, i};
     if (i % 2 == 0) {
       auto p_chan = worker.Comm().Chan(i + 1);
-      p_chan->SendAll(
-          EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
-      auto rc = p_chan->Block();
+      auto rc = Success() << [&] {
+        return p_chan->SendAll(
+            EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
+      } << [&] { return p_chan->Block(); };
       ASSERT_TRUE(rc.OK()) << rc.Report();
     } else {
       auto p_chan = worker.Comm().Chan(i - 1);
       std::int32_t r{-1};
-      p_chan->RecvAll(EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
-      auto rc = p_chan->Block();
+      auto rc = Success() << [&] {
+        return p_chan->RecvAll(
+            EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
+      } << [&] { return p_chan->Block(); };
       ASSERT_TRUE(rc.OK()) << rc.Report();
       ASSERT_EQ(r, i - 1);
     }
@@ -23,7 +23,7 @@ TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {

 TEST(NcclDeviceCommunicatorSimpleTest, SystemError) {
   auto stub = std::make_shared<NcclStub>(DefaultNcclName());
-  auto rc = GetNCCLResult(stub, ncclSystemError);
+  auto rc = stub->GetNcclResult(ncclSystemError);
   auto msg = rc.Report();
   ASSERT_TRUE(msg.find("environment variables") != std::string::npos);
 }