diff --git a/src/collective/allgather.cc b/src/collective/allgather.cc index a51b79fbc..fa369a9da 100644 --- a/src/collective/allgather.cc +++ b/src/collective/allgather.cc @@ -7,20 +7,23 @@ #include // for size_t #include // for int8_t, int32_t, int64_t #include // for shared_ptr -#include // for partial_sum -#include // for vector +#include "broadcast.h" #include "comm.h" // for Comm, Channel #include "xgboost/collective/result.h" // for Result #include "xgboost/span.h" // for Span -namespace xgboost::collective::cpu_impl { +namespace xgboost::collective { +namespace cpu_impl { Result RingAllgather(Comm const& comm, common::Span data, std::size_t segment_size, std::int32_t worker_off, std::shared_ptr prev_ch, std::shared_ptr next_ch) { auto world = comm.World(); auto rank = comm.Rank(); CHECK_LT(worker_off, world); + if (world == 1) { + return Success(); + } for (std::int32_t r = 0; r < world; ++r) { auto send_rank = (rank + world - r + worker_off) % world; @@ -43,11 +46,29 @@ Result RingAllgather(Comm const& comm, common::Span data, std::size return Success(); } +Result BroadcastAllgatherV(Comm const& comm, common::Span sizes, + common::Span recv) { + std::size_t offset = 0; + for (std::int32_t r = 0; r < comm.World(); ++r) { + auto as_bytes = sizes[r]; + auto rc = Broadcast(comm, recv.subspan(offset, as_bytes), r); + if (!rc.OK()) { + return rc; + } + offset += as_bytes; + } + return Success(); +} +} // namespace cpu_impl + +namespace detail { [[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span sizes, - common::Span data, - common::Span offset, + common::Span offset, common::Span erased_result) { auto world = comm.World(); + if (world == 1) { + return Success(); + } auto rank = comm.Rank(); auto prev = BootstrapPrev(rank, comm.World()); @@ -56,17 +77,6 @@ Result RingAllgather(Comm const& comm, common::Span data, std::size auto prev_ch = comm.Chan(prev); auto next_ch = comm.Chan(next); - // get worker offset - CHECK_EQ(world + 1, offset.size()); - std::fill_n(offset.data(), offset.size(), 0); - std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1); - CHECK_EQ(*offset.cbegin(), 0); - - // copy data - auto current = erased_result.subspan(offset[rank], data.size_bytes()); - auto erased_data = EraseType(data); - std::copy_n(erased_data.data(), erased_data.size(), current.data()); - for (std::int32_t r = 0; r < world; ++r) { auto send_rank = (rank + world - r) % world; auto send_off = offset[send_rank]; @@ -87,4 +97,5 @@ Result RingAllgather(Comm const& comm, common::Span data, std::size } return comm.Block(); } -} // namespace xgboost::collective::cpu_impl +} // namespace detail +} // namespace xgboost::collective diff --git a/src/collective/allgather.h b/src/collective/allgather.h index a566da78d..4f13014be 100644 --- a/src/collective/allgather.h +++ b/src/collective/allgather.h @@ -9,28 +9,47 @@ #include // for remove_cv_t #include // for vector -#include "../common/type.h" // for EraseType +#include "../common/type.h" // for EraseType #include "comm.h" // for Comm, Channel #include "xgboost/collective/result.h" // for Result -#include "xgboost/span.h" // for Span +#include "xgboost/linalg.h" +#include "xgboost/span.h" // for Span namespace xgboost::collective { namespace cpu_impl { /** - * @param worker_off Segment offset. For example, if the rank 2 worker specifis worker_off - * = 1, then it owns the third segment. + * @param worker_off Segment offset. For example, if the rank 2 worker specifies + * worker_off = 1, then it owns the third segment. */ [[nodiscard]] Result RingAllgather(Comm const& comm, common::Span data, std::size_t segment_size, std::int32_t worker_off, std::shared_ptr prev_ch, std::shared_ptr next_ch); -[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span sizes, - common::Span data, - common::Span offset, - common::Span erased_result); +/** + * @brief Implement allgather-v using broadcast. + * + * https://arxiv.org/abs/1812.05964 + */ +Result BroadcastAllgatherV(Comm const& comm, common::Span sizes, + common::Span recv); } // namespace cpu_impl +namespace detail { +inline void AllgatherVOffset(common::Span sizes, + common::Span offset) { + // get worker offset + std::fill_n(offset.data(), offset.size(), 0); + std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1); + CHECK_EQ(*offset.cbegin(), 0); +} + +// An implementation that's used by both cpu and gpu +[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span sizes, + common::Span offset, + common::Span erased_result); +} // namespace detail + template [[nodiscard]] Result RingAllgather(Comm const& comm, common::Span data, std::size_t size) { auto n_bytes = sizeof(T) * size; @@ -68,9 +87,15 @@ template auto h_result = common::Span{result.data(), result.size()}; auto erased_result = common::EraseType(h_result); auto erased_data = common::EraseType(data); - std::vector offset(world + 1); + std::vector recv_segments(world + 1); + auto s_segments = common::Span{recv_segments.data(), recv_segments.size()}; - return cpu_impl::RingAllgatherV(comm, sizes, erased_data, - common::Span{offset.data(), offset.size()}, erased_result); + // get worker offset + detail::AllgatherVOffset(sizes, s_segments); + // copy data + auto current = erased_result.subspan(recv_segments[rank], data.size_bytes()); + std::copy_n(erased_data.data(), erased_data.size(), current.data()); + + return detail::RingAllgatherV(comm, sizes, s_segments, erased_result); } } // namespace xgboost::collective diff --git a/src/collective/coll.cc b/src/collective/coll.cc index 6682e57ff..598e6129d 100644 --- a/src/collective/coll.cc +++ b/src/collective/coll.cc @@ -8,16 +8,14 @@ #include // for int8_t, int64_t #include // for bit_and, bit_or, bit_xor, plus -#include "allgather.h" // for RingAllgatherV, RingAllgather -#include "allreduce.h" // for Allreduce -#include "broadcast.h" // for Broadcast -#include "comm.h" // for Comm -#include "xgboost/context.h" // for Context +#include "allgather.h" // for RingAllgatherV, RingAllgather +#include "allreduce.h" // for Allreduce +#include "broadcast.h" // for Broadcast +#include "comm.h" // for Comm namespace xgboost::collective { -[[nodiscard]] Result Coll::Allreduce(Context const*, Comm const& comm, - common::Span data, ArrayInterfaceHandler::Type, - Op op) { +[[nodiscard]] Result Coll::Allreduce(Comm const& comm, common::Span data, + ArrayInterfaceHandler::Type, Op op) { namespace coll = ::xgboost::collective; auto redop_fn = [](auto lhs, auto out, auto elem_op) { @@ -55,21 +53,45 @@ namespace xgboost::collective { return comm.Block(); } -[[nodiscard]] Result Coll::Broadcast(Context const*, Comm const& comm, - common::Span data, std::int32_t root) { +[[nodiscard]] Result Coll::Broadcast(Comm const& comm, common::Span data, + std::int32_t root) { return cpu_impl::Broadcast(comm, data, root); } -[[nodiscard]] Result Coll::Allgather(Context const*, Comm const& comm, - common::Span data, std::size_t size) { +[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span data, + std::int64_t size) { return RingAllgather(comm, data, size); } -[[nodiscard]] Result Coll::AllgatherV(Context const*, Comm const& comm, - common::Span data, +[[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span data, common::Span sizes, common::Span recv_segments, - common::Span recv) { - return cpu_impl::RingAllgatherV(comm, sizes, data, recv_segments, recv); + common::Span recv, AllgatherVAlgo algo) { + // get worker offset + detail::AllgatherVOffset(sizes, recv_segments); + + // copy data + auto current = recv.subspan(recv_segments[comm.Rank()], data.size_bytes()); + if (current.data() != data.data()) { + std::copy_n(data.data(), data.size(), current.data()); + } + + switch (algo) { + case AllgatherVAlgo::kRing: + return detail::RingAllgatherV(comm, sizes, recv_segments, recv); + case AllgatherVAlgo::kBcast: + return cpu_impl::BroadcastAllgatherV(comm, sizes, recv); + default: { + return Fail("Unknown algorithm for allgather-v"); + } + } } + +#if !defined(XGBOOST_USE_NCCL) +Coll* Coll::MakeCUDAVar() { + LOG(FATAL) << "NCCL is required for device communication."; + return nullptr; +} +#endif + } // namespace xgboost::collective diff --git a/src/collective/coll.cu b/src/collective/coll.cu new file mode 100644 index 000000000..bac9fb094 --- /dev/null +++ b/src/collective/coll.cu @@ -0,0 +1,254 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#if defined(XGBOOST_USE_NCCL) +#include // for int8_t, int64_t + +#include "../common/cuda_context.cuh" +#include "../common/device_helpers.cuh" +#include "../data/array_interface.h" +#include "allgather.h" // for AllgatherVOffset +#include "coll.cuh" +#include "comm.cuh" +#include "nccl.h" +#include "xgboost/collective/result.h" // for Result +#include "xgboost/span.h" // for Span + +namespace xgboost::collective { +Coll* Coll::MakeCUDAVar() { return new NCCLColl{}; } + +NCCLColl::~NCCLColl() = default; +namespace { +Result GetNCCLResult(ncclResult_t code) { + if (code == ncclSuccess) { + return Success(); + } + + std::stringstream ss; + ss << "NCCL failure: " << ncclGetErrorString(code) << "."; + if (code == ncclUnhandledCudaError) { + // nccl usually preserves the last error so we can get more details. + auto err = cudaPeekAtLastError(); + ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n"; + } else if (code == ncclSystemError) { + ss << " This might be caused by a network configuration issue. Please consider specifying " + "the network interface for NCCL via environment variables listed in its reference: " + "`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n"; + } + return Fail(ss.str()); +} + +auto GetNCCLType(ArrayInterfaceHandler::Type type) { + auto fatal = [] { + LOG(FATAL) << "Invalid type for NCCL operation."; + return ncclHalf; // dummy return to silent the compiler warning. + }; + using H = ArrayInterfaceHandler; + switch (type) { + case H::kF2: + return ncclHalf; + case H::kF4: + return ncclFloat32; + case H::kF8: + return ncclFloat64; + case H::kF16: + return fatal(); + case H::kI1: + return ncclInt8; + case H::kI2: + return fatal(); + case H::kI4: + return ncclInt32; + case H::kI8: + return ncclInt64; + case H::kU1: + return ncclUint8; + case H::kU2: + return fatal(); + case H::kU4: + return ncclUint32; + case H::kU8: + return ncclUint64; + } + return fatal(); +} + +bool IsBitwiseOp(Op const& op) { + return op == Op::kBitwiseAND || op == Op::kBitwiseOR || op == Op::kBitwiseXOR; +} + +template +void RunBitwiseAllreduce(dh::CUDAStreamView stream, common::Span out_buffer, + std::int8_t const* device_buffer, Func func, std::int32_t world_size, + std::size_t size) { + dh::LaunchN(size, stream, [=] __device__(std::size_t idx) { + auto result = device_buffer[idx]; + for (auto rank = 1; rank < world_size; rank++) { + result = func(result, device_buffer[rank * size + idx]); + } + out_buffer[idx] = result; + }); +} + +[[nodiscard]] Result BitwiseAllReduce(NCCLComm const* pcomm, ncclComm_t handle, + common::Span data, Op op) { + dh::device_vector buffer(data.size() * pcomm->World()); + auto* device_buffer = buffer.data().get(); + + // First gather data from all the workers. + CHECK(handle); + auto rc = GetNCCLResult( + ncclAllGather(data.data(), device_buffer, data.size(), ncclInt8, handle, pcomm->Stream())); + if (!rc.OK()) { + return rc; + } + + // Then reduce locally. + switch (op) { + case Op::kBitwiseAND: + RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_and(), + pcomm->World(), data.size()); + break; + case Op::kBitwiseOR: + RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_or(), + pcomm->World(), data.size()); + break; + case Op::kBitwiseXOR: + RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_xor(), + pcomm->World(), data.size()); + break; + default: + LOG(FATAL) << "Not a bitwise reduce operation."; + } + return Success(); +} + +ncclRedOp_t GetNCCLRedOp(Op const& op) { + ncclRedOp_t result{ncclMax}; + switch (op) { + case Op::kMax: + result = ncclMax; + break; + case Op::kMin: + result = ncclMin; + break; + case Op::kSum: + result = ncclSum; + break; + default: + LOG(FATAL) << "Unsupported reduce operation."; + } + return result; +} +} // namespace + +[[nodiscard]] Result NCCLColl::Allreduce(Comm const& comm, common::Span data, + ArrayInterfaceHandler::Type type, Op op) { + if (!comm.IsDistributed()) { + return Success(); + } + auto nccl = dynamic_cast(&comm); + CHECK(nccl); + return Success() << [&] { + if (IsBitwiseOp(op)) { + return BitwiseAllReduce(nccl, nccl->Handle(), data, op); + } else { + return DispatchDType(type, [=](auto t) { + using T = decltype(t); + auto rdata = common::RestoreType(data); + auto rc = ncclAllReduce(data.data(), data.data(), rdata.size(), GetNCCLType(type), + GetNCCLRedOp(op), nccl->Handle(), nccl->Stream()); + return GetNCCLResult(rc); + }); + } + } << [&] { return nccl->Block(); }; +} + +[[nodiscard]] Result NCCLColl::Broadcast(Comm const& comm, common::Span data, + std::int32_t root) { + if (!comm.IsDistributed()) { + return Success(); + } + auto nccl = dynamic_cast(&comm); + CHECK(nccl); + return Success() << [&] { + return GetNCCLResult(ncclBroadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root, + nccl->Handle(), nccl->Stream())); + } << [&] { return nccl->Block(); }; +} + +[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span data, + std::int64_t size) { + if (!comm.IsDistributed()) { + return Success(); + } + auto nccl = dynamic_cast(&comm); + CHECK(nccl); + auto send = data.subspan(comm.Rank() * size, size); + return Success() << [&] { + return GetNCCLResult( + ncclAllGather(send.data(), data.data(), size, ncclInt8, nccl->Handle(), nccl->Stream())); + } << [&] { return nccl->Block(); }; +} + +namespace cuda_impl { +/** + * @brief Implement allgather-v using broadcast. + * + * https://arxiv.org/abs/1812.05964 + */ +Result BroadcastAllgatherV(NCCLComm const* comm, common::Span data, + common::Span sizes, common::Span recv) { + return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] { + std::size_t offset = 0; + for (std::int32_t r = 0; r < comm->World(); ++r) { + auto as_bytes = sizes[r]; + auto rc = ncclBroadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes, + ncclInt8, r, comm->Handle(), dh::DefaultStream()); + if (rc != ncclSuccess) { + return GetNCCLResult(rc); + } + offset += as_bytes; + } + return Success(); + } << [] { return GetNCCLResult(ncclGroupEnd()); }; +} +} // namespace cuda_impl + +[[nodiscard]] Result NCCLColl::AllgatherV(Comm const& comm, common::Span data, + common::Span sizes, + common::Span recv_segments, + common::Span recv, AllgatherVAlgo algo) { + auto nccl = dynamic_cast(&comm); + CHECK(nccl); + if (!comm.IsDistributed()) { + return Success(); + } + + switch (algo) { + case AllgatherVAlgo::kRing: { + return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] { + // get worker offset + detail::AllgatherVOffset(sizes, recv_segments); + // copy data + auto current = recv.subspan(recv_segments[comm.Rank()], data.size_bytes()); + if (current.data() != data.data()) { + dh::safe_cuda(cudaMemcpyAsync(current.data(), data.data(), current.size_bytes(), + cudaMemcpyDeviceToDevice, nccl->Stream())); + } + return detail::RingAllgatherV(comm, sizes, recv_segments, recv); + } << [] { + return GetNCCLResult(ncclGroupEnd()); + } << [&] { return nccl->Block(); }; + } + case AllgatherVAlgo::kBcast: { + return cuda_impl::BroadcastAllgatherV(nccl, data, sizes, recv); + } + default: { + return Fail("Unknown algorithm for allgather-v"); + } + } +} +} // namespace xgboost::collective + +#endif // defined(XGBOOST_USE_NCCL) diff --git a/src/collective/coll.cuh b/src/collective/coll.cuh new file mode 100644 index 000000000..87fb46711 --- /dev/null +++ b/src/collective/coll.cuh @@ -0,0 +1,29 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once + +#include // for int8_t, int64_t + +#include "../data/array_interface.h" // for ArrayInterfaceHandler +#include "coll.h" // for Coll +#include "comm.h" // for Comm +#include "xgboost/span.h" // for Span + +namespace xgboost::collective { +class NCCLColl : public Coll { + public: + ~NCCLColl() override; + + [[nodiscard]] Result Allreduce(Comm const& comm, common::Span data, + ArrayInterfaceHandler::Type type, Op op) override; + [[nodiscard]] Result Broadcast(Comm const& comm, common::Span data, + std::int32_t root) override; + [[nodiscard]] Result Allgather(Comm const& comm, common::Span data, + std::int64_t size) override; + [[nodiscard]] Result AllgatherV(Comm const& comm, common::Span data, + common::Span sizes, + common::Span recv_segments, + common::Span recv, AllgatherVAlgo algo) override; +}; +} // namespace xgboost::collective diff --git a/src/collective/coll.h b/src/collective/coll.h index 9a318db8d..0189ffd5e 100644 --- a/src/collective/coll.h +++ b/src/collective/coll.h @@ -2,17 +2,20 @@ * Copyright 2023, XGBoost Contributors */ #pragma once -#include // for size_t #include // for int8_t, int64_t #include // for enable_shared_from_this #include "../data/array_interface.h" // for ArrayInterfaceHandler #include "comm.h" // for Comm #include "xgboost/collective/result.h" // for Result -#include "xgboost/context.h" // for Context #include "xgboost/span.h" // for Span namespace xgboost::collective { +enum class AllgatherVAlgo { + kRing = 0, // use ring-based allgather-v + kBcast = 1, // use broadcast-based allgather-v +}; + /** * @brief Interface and base implementation for collective. */ @@ -21,6 +24,8 @@ class Coll : public std::enable_shared_from_this { Coll() = default; virtual ~Coll() noexcept(false) {} // NOLINT + Coll* MakeCUDAVar(); + /** * @brief Allreduce * @@ -29,8 +34,7 @@ class Coll : public std::enable_shared_from_this { * @param [in] op Reduce operation. For custom operation, user needs to reach down to * the CPU implementation. */ - [[nodiscard]] virtual Result Allreduce(Context const* ctx, Comm const& comm, - common::Span data, + [[nodiscard]] virtual Result Allreduce(Comm const& comm, common::Span data, ArrayInterfaceHandler::Type type, Op op); /** * @brief Broadcast @@ -38,29 +42,29 @@ class Coll : public std::enable_shared_from_this { * @param [in,out] data Data buffer for input and output. * @param [in] root Root rank for broadcast. */ - [[nodiscard]] virtual Result Broadcast(Context const* ctx, Comm const& comm, - common::Span data, std::int32_t root); + [[nodiscard]] virtual Result Broadcast(Comm const& comm, common::Span data, + std::int32_t root); /** * @brief Allgather * * @param [in,out] data Data buffer for input and output. * @param [in] size Size of data for each worker. */ - [[nodiscard]] virtual Result Allgather(Context const* ctx, Comm const& comm, - common::Span data, std::size_t size); + [[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span data, + std::int64_t size); /** * @brief Allgather with variable length. * * @param [in] data Input data for the current worker. * @param [in] sizes Size of the input from each worker. - * @param [out] recv_segments pre-allocated offset for each worker in the output, size - * should be equal to (world + 1). + * @param [out] recv_segments pre-allocated offset buffer for each worker in the output, + * size should be equal to (world + 1). GPU ring-based implementation + * doesn't use the buffer. * @param [out] recv pre-allocated buffer for output. */ - [[nodiscard]] virtual Result AllgatherV(Context const* ctx, Comm const& comm, - common::Span data, + [[nodiscard]] virtual Result AllgatherV(Comm const& comm, common::Span data, common::Span sizes, common::Span recv_segments, - common::Span recv); + common::Span recv, AllgatherVAlgo algo); }; } // namespace xgboost::collective diff --git a/src/collective/comm.cc b/src/collective/comm.cc index 3c49303fa..dbd45cbb2 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -262,7 +262,7 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se } RabitComm::~RabitComm() noexcept(false) { - if (!IsDistributed()) { + if (!this->IsDistributed()) { return; } auto rc = this->Shutdown(); diff --git a/src/collective/comm.cu b/src/collective/comm.cu new file mode 100644 index 000000000..31a06e124 --- /dev/null +++ b/src/collective/comm.cu @@ -0,0 +1,112 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#if defined(XGBOOST_USE_NCCL) +#include // for sort +#include // for size_t +#include // for uint64_t, int8_t +#include // for memcpy +#include // for shared_ptr +#include // for stringstream +#include // for vector + +#include "../common/device_helpers.cuh" // for DefaultStream +#include "../common/type.h" // for EraseType +#include "broadcast.h" // for Broadcast +#include "comm.cuh" // for NCCLComm +#include "comm.h" // for Comm +#include "xgboost/collective/result.h" // for Result +#include "xgboost/span.h" // for Span + +namespace xgboost::collective { +namespace { +Result GetUniqueId(Comm const& comm, ncclUniqueId* pid) { + static const int kRootRank = 0; + ncclUniqueId id; + if (comm.Rank() == kRootRank) { + dh::safe_nccl(ncclGetUniqueId(&id)); + } + auto rc = Broadcast(comm, common::Span{reinterpret_cast(&id), sizeof(ncclUniqueId)}, + kRootRank); + if (!rc.OK()) { + return rc; + } + *pid = id; + return Success(); +} + +inline constexpr std::size_t kUuidLength = + sizeof(std::declval().uuid) / sizeof(std::uint64_t); + +void GetCudaUUID(xgboost::common::Span const& uuid, DeviceOrd device) { + cudaDeviceProp prob{}; + dh::safe_cuda(cudaGetDeviceProperties(&prob, device.ordinal)); + std::memcpy(uuid.data(), static_cast(&(prob.uuid)), sizeof(prob.uuid)); +} + +static std::string PrintUUID(xgboost::common::Span const& uuid) { + std::stringstream ss; + for (auto v : uuid) { + ss << std::hex << v; + } + return ss.str(); +} +} // namespace + +Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr pimpl) { + return new NCCLComm{ctx, *this, pimpl}; +} + +NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr pimpl) + : Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(), + root.TaskID()}, + stream_{dh::DefaultStream()} { + this->world_ = root.World(); + this->rank_ = root.Rank(); + this->domain_ = root.Domain(); + if (!root.IsDistributed()) { + return; + } + + dh::safe_cuda(cudaSetDevice(ctx->Ordinal())); + + std::vector uuids(root.World() * kUuidLength, 0); + auto s_uuid = xgboost::common::Span{uuids.data(), uuids.size()}; + auto s_this_uuid = s_uuid.subspan(root.Rank() * kUuidLength, kUuidLength); + GetCudaUUID(s_this_uuid, ctx->Device()); + + auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes()); + CHECK(rc.OK()) << rc.Report(); + + std::vector> converted(root.World()); + std::size_t j = 0; + for (size_t i = 0; i < uuids.size(); i += kUuidLength) { + converted[j] = s_uuid.subspan(i, kUuidLength); + j++; + } + + std::sort(converted.begin(), converted.end()); + auto iter = std::unique(converted.begin(), converted.end()); + auto n_uniques = std::distance(converted.begin(), iter); + + CHECK_EQ(n_uniques, root.World()) + << "Multiple processes within communication group running on same CUDA " + << "device is not supported. " << PrintUUID(s_this_uuid) << "\n"; + + rc = GetUniqueId(root, &nccl_unique_id_); + CHECK(rc.OK()) << rc.Report(); + dh::safe_nccl(ncclCommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank())); + + for (std::int32_t r = 0; r < root.World(); ++r) { + this->channels_.emplace_back( + std::make_shared(root, r, nccl_comm_, dh::DefaultStream())); + } +} + +NCCLComm::~NCCLComm() { + if (nccl_comm_) { + dh::safe_nccl(ncclCommDestroy(nccl_comm_)); + } +} +} // namespace xgboost::collective +#endif // defined(XGBOOST_USE_NCCL) diff --git a/src/collective/comm.cuh b/src/collective/comm.cuh new file mode 100644 index 000000000..ea15c50f3 --- /dev/null +++ b/src/collective/comm.cuh @@ -0,0 +1,67 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once + +#ifdef XGBOOST_USE_NCCL +#include "nccl.h" +#endif // XGBOOST_USE_NCCL +#include "../common/device_helpers.cuh" +#include "coll.h" +#include "comm.h" +#include "xgboost/context.h" + +namespace xgboost::collective { + +inline Result GetCUDAResult(cudaError rc) { + if (rc == cudaSuccess) { + return Success(); + } + std::string msg = thrust::system_error(rc, thrust::cuda_category()).what(); + return Fail(msg); +} + +class NCCLComm : public Comm { + ncclComm_t nccl_comm_{nullptr}; + ncclUniqueId nccl_unique_id_{}; + dh::CUDAStreamView stream_; + + public: + [[nodiscard]] ncclComm_t Handle() const { return nccl_comm_; } + + explicit NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr pimpl); + [[nodiscard]] Result LogTracker(std::string) const override { + LOG(FATAL) << "Device comm is used for logging."; + return Fail("Undefined."); + } + ~NCCLComm() override; + [[nodiscard]] bool IsFederated() const override { return false; } + [[nodiscard]] dh::CUDAStreamView Stream() const { return stream_; } + [[nodiscard]] Result Block() const override { + auto rc = this->Stream().Sync(false); + return GetCUDAResult(rc); + } +}; + +class NCCLChannel : public Channel { + std::int32_t rank_{-1}; + ncclComm_t nccl_comm_{}; + dh::CUDAStreamView stream_; + + public: + explicit NCCLChannel(Comm const& comm, std::int32_t rank, ncclComm_t nccl_comm, + dh::CUDAStreamView stream) + : rank_{rank}, nccl_comm_{nccl_comm}, Channel{comm, nullptr}, stream_{stream} {} + + void SendAll(std::int8_t const* ptr, std::size_t n) override { + dh::safe_nccl(ncclSend(ptr, n, ncclInt8, rank_, nccl_comm_, stream_)); + } + void RecvAll(std::int8_t* ptr, std::size_t n) override { + dh::safe_nccl(ncclRecv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_)); + } + [[nodiscard]] Result Block() override { + auto rc = stream_.Sync(false); + return GetCUDAResult(rc); + } +}; +} // namespace xgboost::collective diff --git a/src/collective/comm.h b/src/collective/comm.h index adf23b9e4..afb543c46 100644 --- a/src/collective/comm.h +++ b/src/collective/comm.h @@ -2,20 +2,20 @@ * Copyright 2023, XGBoost Contributors */ #pragma once -#include // for seconds -#include // for size_t -#include // for int32_t -#include // for shared_ptr -#include // for string -#include // for thread -#include // for remove_const_t -#include // for move -#include // for vector +#include // for seconds +#include // for size_t +#include // for int32_t +#include // for shared_ptr +#include // for string +#include // for thread +#include // for move +#include // for vector #include "loop.h" // for Loop #include "protocol.h" // for PeerInfo #include "xgboost/collective/result.h" // for Result #include "xgboost/collective/socket.h" // for TCPSocket +#include "xgboost/context.h" // for Context #include "xgboost/span.h" // for Span namespace xgboost::collective { @@ -35,13 +35,14 @@ inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) { } class Channel; +class Coll; /** * @brief Base communicator storing info about the tracker and other communicators. */ class Comm { protected: - std::int32_t world_{1}; + std::int32_t world_{-1}; std::int32_t rank_{0}; std::chrono::seconds timeout_{DefaultTimeoutSec()}; std::int32_t retry_{DefaultRetry()}; @@ -69,12 +70,14 @@ class Comm { [[nodiscard]] Result ConnectTracker(TCPSocket* out) const; [[nodiscard]] auto Domain() const { return domain_; } [[nodiscard]] auto Timeout() const { return timeout_; } + [[nodiscard]] auto Retry() const { return retry_; } + [[nodiscard]] auto TaskID() const { return task_id_; } [[nodiscard]] auto Rank() const { return rank_; } - [[nodiscard]] auto World() const { return world_; } - [[nodiscard]] bool IsDistributed() const { return World() > 1; } + [[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; } + [[nodiscard]] bool IsDistributed() const { return world_ != -1; } void Submit(Loop::Op op) const { loop_->Submit(op); } - [[nodiscard]] Result Block() const { return loop_->Block(); } + [[nodiscard]] virtual Result Block() const { return loop_->Block(); } [[nodiscard]] virtual std::shared_ptr Chan(std::int32_t rank) const { return channels_.at(rank); @@ -83,6 +86,8 @@ class Comm { [[nodiscard]] virtual Result LogTracker(std::string msg) const = 0; [[nodiscard]] virtual Result SignalError(Result const&) { return Success(); } + + Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr pimpl); }; class RabitComm : public Comm { @@ -116,7 +121,7 @@ class Channel { explicit Channel(Comm const& comm, std::shared_ptr sock) : sock_{std::move(sock)}, comm_{comm} {} - void SendAll(std::int8_t const* ptr, std::size_t n) { + virtual void SendAll(std::int8_t const* ptr, std::size_t n) { Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast(ptr), n, sock_.get(), 0}; CHECK(sock_.get()); comm_.Submit(std::move(op)); @@ -125,7 +130,7 @@ class Channel { this->SendAll(data.data(), data.size_bytes()); } - void RecvAll(std::int8_t* ptr, std::size_t n) { + virtual void RecvAll(std::int8_t* ptr, std::size_t n) { Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0}; CHECK(sock_.get()); comm_.Submit(std::move(op)); @@ -133,7 +138,7 @@ class Channel { void RecvAll(common::Span data) { this->RecvAll(data.data(), data.size_bytes()); } [[nodiscard]] auto Socket() const { return sock_; } - [[nodiscard]] Result Block() { return comm_.Block(); } + [[nodiscard]] virtual Result Block() { return comm_.Block(); } }; enum class Op { kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 }; diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index dfaac9c35..89b3ad2e6 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -1169,7 +1169,13 @@ class CUDAStreamView { operator cudaStream_t() const { // NOLINT return stream_; } - void Sync() { dh::safe_cuda(cudaStreamSynchronize(stream_)); } + cudaError_t Sync(bool error = true) { + if (error) { + dh::safe_cuda(cudaStreamSynchronize(stream_)); + return cudaSuccess; + } + return cudaStreamSynchronize(stream_); + } }; inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 8fd3120b5..6db201dd5 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -19,7 +19,6 @@ #include "../common/cuda_context.cuh" // CUDAContext #include "../common/device_helpers.cuh" #include "../common/hist_util.h" -#include "../common/io.h" #include "../common/timer.h" #include "../data/ellpack_page.cuh" #include "../data/ellpack_page.h" @@ -39,7 +38,6 @@ #include "xgboost/data.h" #include "xgboost/host_device_vector.h" #include "xgboost/json.h" -#include "xgboost/parameter.h" #include "xgboost/span.h" #include "xgboost/task.h" // for ObjInfo #include "xgboost/tree_model.h" diff --git a/tests/cpp/collective/test_allgather.cc b/tests/cpp/collective/test_allgather.cc index a74b9f149..bdfadc0c7 100644 --- a/tests/cpp/collective/test_allgather.cc +++ b/tests/cpp/collective/test_allgather.cc @@ -14,6 +14,7 @@ #include // for vector #include "../../../src/collective/allgather.h" // for RingAllgather +#include "../../../src/collective/coll.h" // for Coll #include "../../../src/collective/comm.h" // for RabitComm #include "gtest/gtest.h" // for AssertionR... #include "test_worker.h" // for TestDistri... @@ -63,37 +64,79 @@ class Worker : public WorkerForTest { } } - void TestV() { - { - // basic test - std::int32_t n{comm_.Rank()}; - std::vector result; - auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result); - ASSERT_TRUE(rc.OK()) << rc.Report(); - for (std::int32_t i = 0; i < comm_.World(); ++i) { - ASSERT_EQ(result[i], i); - } - } - - { - // V test - std::vector data(comm_.Rank() + 1, comm_.Rank()); - std::vector result; - auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result); - ASSERT_TRUE(rc.OK()) << rc.Report(); - ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2); - std::int32_t k{0}; - for (std::int32_t r = 0; r < comm_.World(); ++r) { - auto seg = common::Span{result.data(), result.size()}.subspan(k, (r + 1)); - if (comm_.Rank() == 0) { - for (auto v : seg) { - ASSERT_EQ(v, r); - } - k += seg.size(); + void CheckV(common::Span result) { + std::int32_t k{0}; + for (std::int32_t r = 0; r < comm_.World(); ++r) { + auto seg = common::Span{result.data(), result.size()}.subspan(k, (r + 1)); + if (comm_.Rank() == 0) { + for (auto v : seg) { + ASSERT_EQ(v, r); } + k += seg.size(); } } } + void TestVRing() { + // V test + std::vector data(comm_.Rank() + 1, comm_.Rank()); + std::vector result; + auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result); + ASSERT_TRUE(rc.OK()) << rc.Report(); + ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2); + CheckV(result); + } + + void TestVBasic() { + // basic test + std::int32_t n{comm_.Rank()}; + std::vector result; + auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result); + ASSERT_TRUE(rc.OK()) << rc.Report(); + for (std::int32_t i = 0; i < comm_.World(); ++i) { + ASSERT_EQ(result[i], i); + } + } + + void TestVAlgo() { + // V test, broadcast + std::vector data(comm_.Rank() + 1, comm_.Rank()); + auto s_data = common::Span{data.data(), data.size()}; + + std::vector sizes(comm_.World(), 0); + sizes[comm_.Rank()] = s_data.size_bytes(); + auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1); + ASSERT_TRUE(rc.OK()) << rc.Report(); + std::shared_ptr pcoll{new Coll{}}; + + std::vector recv_segments(comm_.World() + 1, 0); + std::vector recv(std::accumulate(sizes.cbegin(), sizes.cend(), 0)); + + auto s_recv = common::Span{recv.data(), recv.size()}; + + rc = pcoll->AllgatherV(comm_, common::EraseType(s_data), + common::Span{sizes.data(), sizes.size()}, + common::Span{recv_segments.data(), recv_segments.size()}, + common::EraseType(s_recv), AllgatherVAlgo::kBcast); + ASSERT_TRUE(rc.OK()); + CheckV(s_recv); + + // Test inplace + auto test_inplace = [&] (AllgatherVAlgo algo) { + std::fill_n(s_recv.data(), s_recv.size(), 0); + auto current = s_recv.subspan(recv_segments[comm_.Rank()], + recv_segments[comm_.Rank() + 1] - recv_segments[comm_.Rank()]); + std::copy_n(data.data(), data.size(), current.data()); + rc = pcoll->AllgatherV(comm_, common::EraseType(current), + common::Span{sizes.data(), sizes.size()}, + common::Span{recv_segments.data(), recv_segments.size()}, + common::EraseType(s_recv), algo); + ASSERT_TRUE(rc.OK()); + CheckV(s_recv); + }; + + test_inplace(AllgatherVAlgo::kBcast); + test_inplace(AllgatherVAlgo::kRing); + } }; } // namespace @@ -106,12 +149,30 @@ TEST_F(AllgatherTest, Basic) { }); } -TEST_F(AllgatherTest, V) { +TEST_F(AllgatherTest, VBasic) { std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { Worker worker{host, port, timeout, n_workers, r}; - worker.TestV(); + worker.TestVBasic(); + }); +} + +TEST_F(AllgatherTest, VRing) { + std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + Worker worker{host, port, timeout, n_workers, r}; + worker.TestVRing(); + }); +} + +TEST_F(AllgatherTest, VAlgo) { + std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + Worker worker{host, port, timeout, n_workers, r}; + worker.TestVAlgo(); }); } } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_allgather.cu b/tests/cpp/collective/test_allgather.cu new file mode 100644 index 000000000..48f7c2615 --- /dev/null +++ b/tests/cpp/collective/test_allgather.cu @@ -0,0 +1,117 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#if defined(XGBOOST_USE_NCCL) +#include +#include // for device_vector +#include // for equal +#include // for Span + +#include // for size_t +#include // for int32_t, int64_t +#include // for vector + +#include "../../../src/collective/allgather.h" // for RingAllgather +#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector +#include "../../../src/common/type.h" // for EraseType +#include "test_worker.cuh" // for NCCLWorkerForTest +#include "test_worker.h" // for TestDistributed, WorkerForTest + +namespace xgboost::collective { +namespace { +class Worker : public NCCLWorkerForTest { + public: + using NCCLWorkerForTest::NCCLWorkerForTest; + + void TestV(AllgatherVAlgo algo) { + { + // basic test + std::size_t n = 1; + // create data + dh::device_vector data(n, comm_.Rank()); + auto s_data = common::EraseType(common::Span{data.data().get(), data.size()}); + // get size + std::vector sizes(comm_.World(), -1); + sizes[comm_.Rank()] = s_data.size_bytes(); + auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1); + ASSERT_TRUE(rc.OK()) << rc.Report(); + // create result + dh::device_vector result(comm_.World(), -1); + auto s_result = common::EraseType(dh::ToSpan(result)); + + std::vector recv_seg(nccl_comm_->World() + 1, 0); + rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()}, + common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo); + ASSERT_TRUE(rc.OK()) << rc.Report(); + + for (std::int32_t i = 0; i < comm_.World(); ++i) { + ASSERT_EQ(result[i], i); + } + } + { + // V test + std::size_t n = 256 * 256; + // create data + dh::device_vector data(n * nccl_comm_->Rank(), nccl_comm_->Rank()); + auto s_data = common::EraseType(common::Span{data.data().get(), data.size()}); + // get size + std::vector sizes(nccl_comm_->World(), 0); + sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes(); + auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1); + ASSERT_TRUE(rc.OK()) << rc.Report(); + auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0); + // create result + dh::device_vector result(n_bytes / sizeof(std::int32_t), -1); + auto s_result = common::EraseType(dh::ToSpan(result)); + + std::vector recv_seg(nccl_comm_->World() + 1, 0); + rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()}, + common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo); + ASSERT_TRUE(rc.OK()) << rc.Report(); + // check segment size + if (algo != AllgatherVAlgo::kBcast) { + auto size = recv_seg[nccl_comm_->Rank() + 1] - recv_seg[nccl_comm_->Rank()]; + ASSERT_EQ(size, n * nccl_comm_->Rank() * sizeof(std::int32_t)); + ASSERT_EQ(size, sizes[nccl_comm_->Rank()]); + } + // check data + std::size_t k{0}; + for (std::int32_t r = 0; r < nccl_comm_->World(); ++r) { + std::size_t s = n * r; + auto current = dh::ToSpan(result).subspan(k, s); + std::vector h_data(current.size()); + dh::CopyDeviceSpanToVector(&h_data, current); + for (auto v : h_data) { + ASSERT_EQ(v, r); + } + k += s; + } + } + } +}; + +class AllgatherTestGPU : public SocketTest {}; +} // namespace + +TEST_F(AllgatherTestGPU, MGPUTestVRing) { + auto n_workers = common::AllVisibleGPUs(); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + Worker w{host, port, timeout, n_workers, r}; + w.Setup(); + w.TestV(AllgatherVAlgo::kRing); + w.TestV(AllgatherVAlgo::kBcast); + }); +} + +TEST_F(AllgatherTestGPU, MGPUTestVBcast) { + auto n_workers = common::AllVisibleGPUs(); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + Worker w{host, port, timeout, n_workers, r}; + w.Setup(); + w.TestV(AllgatherVAlgo::kBcast); + }); +} +} // namespace xgboost::collective +#endif // defined(XGBOOST_USE_NCCL) diff --git a/tests/cpp/collective/test_allreduce.cc b/tests/cpp/collective/test_allreduce.cc index 77d23f6fe..744608dec 100644 --- a/tests/cpp/collective/test_allreduce.cc +++ b/tests/cpp/collective/test_allreduce.cc @@ -6,10 +6,10 @@ #include "../../../src/collective/allreduce.h" #include "../../../src/collective/coll.h" // for Coll #include "../../../src/collective/tracker.h" -#include "test_worker.h" // for WorkerForTest, TestDistributed +#include "../../../src/common/type.h" // for EraseType +#include "test_worker.h" // for WorkerForTest, TestDistributed namespace xgboost::collective { - namespace { class AllreduceWorker : public WorkerForTest { public: @@ -50,11 +50,10 @@ class AllreduceWorker : public WorkerForTest { } void BitOr() { - Context ctx; std::vector data(comm_.World(), 0); data[comm_.Rank()] = ~std::uint32_t{0}; auto pcoll = std::shared_ptr{new Coll{}}; - auto rc = pcoll->Allreduce(&ctx, comm_, EraseType(common::Span{data.data(), data.size()}), + auto rc = pcoll->Allreduce(comm_, common::EraseType(common::Span{data.data(), data.size()}), ArrayInterfaceHandler::kU4, Op::kBitwiseOR); ASSERT_TRUE(rc.OK()) << rc.Report(); for (auto v : data) { diff --git a/tests/cpp/collective/test_allreduce.cu b/tests/cpp/collective/test_allreduce.cu new file mode 100644 index 000000000..af9a4e58f --- /dev/null +++ b/tests/cpp/collective/test_allreduce.cu @@ -0,0 +1,70 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#if defined(XGBOOST_USE_NCCL) +#include +#include // for host_vector + +#include "../../../src/collective/coll.h" // for Coll +#include "../../../src/common/common.h" +#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector +#include "../../../src/common/type.h" // for EraseType +#include "../helpers.h" // for MakeCUDACtx +#include "test_worker.cuh" // for NCCLWorkerForTest +#include "test_worker.h" // for WorkerForTest, TestDistributed + +namespace xgboost::collective { +namespace { +class AllreduceTestGPU : public SocketTest {}; + +class Worker : public NCCLWorkerForTest { + public: + using NCCLWorkerForTest::NCCLWorkerForTest; + + void BitOr() { + dh::device_vector data(comm_.World(), 0); + data[comm_.Rank()] = ~std::uint32_t{0}; + auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)), + ArrayInterfaceHandler::kU4, Op::kBitwiseOR); + ASSERT_TRUE(rc.OK()) << rc.Report(); + thrust::host_vector h_data(data.size()); + thrust::copy(data.cbegin(), data.cend(), h_data.begin()); + for (auto v : h_data) { + ASSERT_EQ(v, ~std::uint32_t{0}); + } + } + + void Acc() { + dh::device_vector data(314, 1.5); + auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)), + ArrayInterfaceHandler::kF8, Op::kSum); + ASSERT_TRUE(rc.OK()) << rc.Report(); + for (std::size_t i = 0; i < data.size(); ++i) { + auto v = data[i]; + ASSERT_EQ(v, 1.5 * static_cast(comm_.World())) << i; + } + } +}; +} // namespace + +TEST_F(AllreduceTestGPU, BitOr) { + auto n_workers = common::AllVisibleGPUs(); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + Worker w{host, port, timeout, n_workers, r}; + w.Setup(); + w.BitOr(); + }); +} + +TEST_F(AllreduceTestGPU, Sum) { + auto n_workers = common::AllVisibleGPUs(); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + Worker w{host, port, timeout, n_workers, r}; + w.Setup(); + w.Acc(); + }); +} +} // namespace xgboost::collective +#endif // defined(XGBOOST_USE_NCCL) diff --git a/tests/cpp/collective/test_broadcast.cc b/tests/cpp/collective/test_broadcast.cc index 0ade86567..4d0d87e93 100644 --- a/tests/cpp/collective/test_broadcast.cc +++ b/tests/cpp/collective/test_broadcast.cc @@ -47,5 +47,5 @@ TEST_F(BroadcastTest, Basic) { Worker worker{host, port, timeout, n_workers, r}; worker.Run(); }); -} +} // namespace } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_worker.cuh b/tests/cpp/collective/test_worker.cuh new file mode 100644 index 000000000..058f8845e --- /dev/null +++ b/tests/cpp/collective/test_worker.cuh @@ -0,0 +1,32 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once +#include // for shared_ptr + +#include "../../../src/collective/coll.h" // for Coll +#include "../../../src/collective/comm.h" // for Comm +#include "test_worker.h" +#include "xgboost/context.h" // for Context + +namespace xgboost::collective { +class NCCLWorkerForTest : public WorkerForTest { + protected: + std::shared_ptr coll_; + std::shared_ptr nccl_comm_; + std::shared_ptr nccl_coll_; + Context ctx_; + + public: + using WorkerForTest::WorkerForTest; + + void Setup() { + ctx_ = MakeCUDACtx(comm_.Rank()); + coll_.reset(new Coll{}); + nccl_comm_.reset(this->comm_.MakeCUDAVar(&ctx_, coll_)); + nccl_coll_.reset(coll_->MakeCUDAVar()); + ASSERT_EQ(comm_.World(), nccl_comm_->World()); + ASSERT_EQ(comm_.Rank(), nccl_comm_->Rank()); + } +}; +} // namespace xgboost::collective diff --git a/tests/cpp/collective/test_worker.h b/tests/cpp/collective/test_worker.h index a3d6de875..6578ff142 100644 --- a/tests/cpp/collective/test_worker.h +++ b/tests/cpp/collective/test_worker.h @@ -1,6 +1,7 @@ /** * Copyright 2023, XGBoost Contributors */ +#pragma once #include #include // for seconds