From 0c621094b3b16f63a9f53395c042c3b6d5f10d80 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 27 Oct 2023 16:38:04 -0500 Subject: [PATCH 1/5] [CI] enforce cmakelint checks (#9728) --- .github/workflows/main.yml | 2 +- CMakeLists.txt | 2 +- .../{FindPrefetchIntrinsics.cmake => PrefetchIntrinsics.cmake} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename cmake/{FindPrefetchIntrinsics.cmake => PrefetchIntrinsics.cmake} (100%) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 1f91afdc5..3fd39bc36 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -151,4 +151,4 @@ jobs: python-package/xgboost/lib python-package/xgboost/rabit \ python-package/xgboost/src - sh ./tests/ci_build/lint_cmake.sh || true + sh ./tests/ci_build/lint_cmake.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index 3608e5670..e93427ed9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,7 +33,7 @@ elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang") endif() endif() -include(${xgboost_SOURCE_DIR}/cmake/FindPrefetchIntrinsics.cmake) +include(${xgboost_SOURCE_DIR}/cmake/PrefetchIntrinsics.cmake) find_prefetch_intrinsics() include(${xgboost_SOURCE_DIR}/cmake/Version.cmake) write_version() diff --git a/cmake/FindPrefetchIntrinsics.cmake b/cmake/PrefetchIntrinsics.cmake similarity index 100% rename from cmake/FindPrefetchIntrinsics.cmake rename to cmake/PrefetchIntrinsics.cmake From 6755179e77f30e1946273cf9fd24ea674fa6219d Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 28 Oct 2023 16:33:58 +0800 Subject: [PATCH 2/5] [coll] Add nccl. (#9726) --- src/collective/allgather.cc | 45 +++-- src/collective/allgather.h | 47 +++-- src/collective/coll.cc | 54 ++++-- src/collective/coll.cu | 254 +++++++++++++++++++++++++ src/collective/coll.cuh | 29 +++ src/collective/coll.h | 30 +-- src/collective/comm.cc | 2 +- src/collective/comm.cu | 112 +++++++++++ src/collective/comm.cuh | 67 +++++++ src/collective/comm.h | 37 ++-- src/common/device_helpers.cuh | 8 +- src/tree/updater_gpu_hist.cu | 2 - tests/cpp/collective/test_allgather.cc | 119 +++++++++--- tests/cpp/collective/test_allgather.cu | 117 ++++++++++++ tests/cpp/collective/test_allreduce.cc | 7 +- tests/cpp/collective/test_allreduce.cu | 70 +++++++ tests/cpp/collective/test_broadcast.cc | 2 +- tests/cpp/collective/test_worker.cuh | 32 ++++ tests/cpp/collective/test_worker.h | 1 + 19 files changed, 924 insertions(+), 111 deletions(-) create mode 100644 src/collective/coll.cu create mode 100644 src/collective/coll.cuh create mode 100644 src/collective/comm.cu create mode 100644 src/collective/comm.cuh create mode 100644 tests/cpp/collective/test_allgather.cu create mode 100644 tests/cpp/collective/test_allreduce.cu create mode 100644 tests/cpp/collective/test_worker.cuh diff --git a/src/collective/allgather.cc b/src/collective/allgather.cc index a51b79fbc..fa369a9da 100644 --- a/src/collective/allgather.cc +++ b/src/collective/allgather.cc @@ -7,20 +7,23 @@ #include // for size_t #include // for int8_t, int32_t, int64_t #include // for shared_ptr -#include // for partial_sum -#include // for vector +#include "broadcast.h" #include "comm.h" // for Comm, Channel #include "xgboost/collective/result.h" // for Result #include "xgboost/span.h" // for Span -namespace xgboost::collective::cpu_impl { +namespace xgboost::collective { +namespace cpu_impl { Result RingAllgather(Comm const& comm, common::Span data, std::size_t segment_size, std::int32_t worker_off, std::shared_ptr prev_ch, std::shared_ptr next_ch) { auto world = comm.World(); auto rank = comm.Rank(); CHECK_LT(worker_off, world); + if (world == 1) { + return Success(); + } for (std::int32_t r = 0; r < world; ++r) { auto send_rank = (rank + world - r + worker_off) % world; @@ -43,11 +46,29 @@ Result RingAllgather(Comm const& comm, common::Span data, std::size return Success(); } +Result BroadcastAllgatherV(Comm const& comm, common::Span sizes, + common::Span recv) { + std::size_t offset = 0; + for (std::int32_t r = 0; r < comm.World(); ++r) { + auto as_bytes = sizes[r]; + auto rc = Broadcast(comm, recv.subspan(offset, as_bytes), r); + if (!rc.OK()) { + return rc; + } + offset += as_bytes; + } + return Success(); +} +} // namespace cpu_impl + +namespace detail { [[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span sizes, - common::Span data, - common::Span offset, + common::Span offset, common::Span erased_result) { auto world = comm.World(); + if (world == 1) { + return Success(); + } auto rank = comm.Rank(); auto prev = BootstrapPrev(rank, comm.World()); @@ -56,17 +77,6 @@ Result RingAllgather(Comm const& comm, common::Span data, std::size auto prev_ch = comm.Chan(prev); auto next_ch = comm.Chan(next); - // get worker offset - CHECK_EQ(world + 1, offset.size()); - std::fill_n(offset.data(), offset.size(), 0); - std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1); - CHECK_EQ(*offset.cbegin(), 0); - - // copy data - auto current = erased_result.subspan(offset[rank], data.size_bytes()); - auto erased_data = EraseType(data); - std::copy_n(erased_data.data(), erased_data.size(), current.data()); - for (std::int32_t r = 0; r < world; ++r) { auto send_rank = (rank + world - r) % world; auto send_off = offset[send_rank]; @@ -87,4 +97,5 @@ Result RingAllgather(Comm const& comm, common::Span data, std::size } return comm.Block(); } -} // namespace xgboost::collective::cpu_impl +} // namespace detail +} // namespace xgboost::collective diff --git a/src/collective/allgather.h b/src/collective/allgather.h index a566da78d..4f13014be 100644 --- a/src/collective/allgather.h +++ b/src/collective/allgather.h @@ -9,28 +9,47 @@ #include // for remove_cv_t #include // for vector -#include "../common/type.h" // for EraseType +#include "../common/type.h" // for EraseType #include "comm.h" // for Comm, Channel #include "xgboost/collective/result.h" // for Result -#include "xgboost/span.h" // for Span +#include "xgboost/linalg.h" +#include "xgboost/span.h" // for Span namespace xgboost::collective { namespace cpu_impl { /** - * @param worker_off Segment offset. For example, if the rank 2 worker specifis worker_off - * = 1, then it owns the third segment. + * @param worker_off Segment offset. For example, if the rank 2 worker specifies + * worker_off = 1, then it owns the third segment. */ [[nodiscard]] Result RingAllgather(Comm const& comm, common::Span data, std::size_t segment_size, std::int32_t worker_off, std::shared_ptr prev_ch, std::shared_ptr next_ch); -[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span sizes, - common::Span data, - common::Span offset, - common::Span erased_result); +/** + * @brief Implement allgather-v using broadcast. + * + * https://arxiv.org/abs/1812.05964 + */ +Result BroadcastAllgatherV(Comm const& comm, common::Span sizes, + common::Span recv); } // namespace cpu_impl +namespace detail { +inline void AllgatherVOffset(common::Span sizes, + common::Span offset) { + // get worker offset + std::fill_n(offset.data(), offset.size(), 0); + std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1); + CHECK_EQ(*offset.cbegin(), 0); +} + +// An implementation that's used by both cpu and gpu +[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span sizes, + common::Span offset, + common::Span erased_result); +} // namespace detail + template [[nodiscard]] Result RingAllgather(Comm const& comm, common::Span data, std::size_t size) { auto n_bytes = sizeof(T) * size; @@ -68,9 +87,15 @@ template auto h_result = common::Span{result.data(), result.size()}; auto erased_result = common::EraseType(h_result); auto erased_data = common::EraseType(data); - std::vector offset(world + 1); + std::vector recv_segments(world + 1); + auto s_segments = common::Span{recv_segments.data(), recv_segments.size()}; - return cpu_impl::RingAllgatherV(comm, sizes, erased_data, - common::Span{offset.data(), offset.size()}, erased_result); + // get worker offset + detail::AllgatherVOffset(sizes, s_segments); + // copy data + auto current = erased_result.subspan(recv_segments[rank], data.size_bytes()); + std::copy_n(erased_data.data(), erased_data.size(), current.data()); + + return detail::RingAllgatherV(comm, sizes, s_segments, erased_result); } } // namespace xgboost::collective diff --git a/src/collective/coll.cc b/src/collective/coll.cc index 6682e57ff..598e6129d 100644 --- a/src/collective/coll.cc +++ b/src/collective/coll.cc @@ -8,16 +8,14 @@ #include // for int8_t, int64_t #include // for bit_and, bit_or, bit_xor, plus -#include "allgather.h" // for RingAllgatherV, RingAllgather -#include "allreduce.h" // for Allreduce -#include "broadcast.h" // for Broadcast -#include "comm.h" // for Comm -#include "xgboost/context.h" // for Context +#include "allgather.h" // for RingAllgatherV, RingAllgather +#include "allreduce.h" // for Allreduce +#include "broadcast.h" // for Broadcast +#include "comm.h" // for Comm namespace xgboost::collective { -[[nodiscard]] Result Coll::Allreduce(Context const*, Comm const& comm, - common::Span data, ArrayInterfaceHandler::Type, - Op op) { +[[nodiscard]] Result Coll::Allreduce(Comm const& comm, common::Span data, + ArrayInterfaceHandler::Type, Op op) { namespace coll = ::xgboost::collective; auto redop_fn = [](auto lhs, auto out, auto elem_op) { @@ -55,21 +53,45 @@ namespace xgboost::collective { return comm.Block(); } -[[nodiscard]] Result Coll::Broadcast(Context const*, Comm const& comm, - common::Span data, std::int32_t root) { +[[nodiscard]] Result Coll::Broadcast(Comm const& comm, common::Span data, + std::int32_t root) { return cpu_impl::Broadcast(comm, data, root); } -[[nodiscard]] Result Coll::Allgather(Context const*, Comm const& comm, - common::Span data, std::size_t size) { +[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span data, + std::int64_t size) { return RingAllgather(comm, data, size); } -[[nodiscard]] Result Coll::AllgatherV(Context const*, Comm const& comm, - common::Span data, +[[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span data, common::Span sizes, common::Span recv_segments, - common::Span recv) { - return cpu_impl::RingAllgatherV(comm, sizes, data, recv_segments, recv); + common::Span recv, AllgatherVAlgo algo) { + // get worker offset + detail::AllgatherVOffset(sizes, recv_segments); + + // copy data + auto current = recv.subspan(recv_segments[comm.Rank()], data.size_bytes()); + if (current.data() != data.data()) { + std::copy_n(data.data(), data.size(), current.data()); + } + + switch (algo) { + case AllgatherVAlgo::kRing: + return detail::RingAllgatherV(comm, sizes, recv_segments, recv); + case AllgatherVAlgo::kBcast: + return cpu_impl::BroadcastAllgatherV(comm, sizes, recv); + default: { + return Fail("Unknown algorithm for allgather-v"); + } + } } + +#if !defined(XGBOOST_USE_NCCL) +Coll* Coll::MakeCUDAVar() { + LOG(FATAL) << "NCCL is required for device communication."; + return nullptr; +} +#endif + } // namespace xgboost::collective diff --git a/src/collective/coll.cu b/src/collective/coll.cu new file mode 100644 index 000000000..bac9fb094 --- /dev/null +++ b/src/collective/coll.cu @@ -0,0 +1,254 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#if defined(XGBOOST_USE_NCCL) +#include // for int8_t, int64_t + +#include "../common/cuda_context.cuh" +#include "../common/device_helpers.cuh" +#include "../data/array_interface.h" +#include "allgather.h" // for AllgatherVOffset +#include "coll.cuh" +#include "comm.cuh" +#include "nccl.h" +#include "xgboost/collective/result.h" // for Result +#include "xgboost/span.h" // for Span + +namespace xgboost::collective { +Coll* Coll::MakeCUDAVar() { return new NCCLColl{}; } + +NCCLColl::~NCCLColl() = default; +namespace { +Result GetNCCLResult(ncclResult_t code) { + if (code == ncclSuccess) { + return Success(); + } + + std::stringstream ss; + ss << "NCCL failure: " << ncclGetErrorString(code) << "."; + if (code == ncclUnhandledCudaError) { + // nccl usually preserves the last error so we can get more details. + auto err = cudaPeekAtLastError(); + ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n"; + } else if (code == ncclSystemError) { + ss << " This might be caused by a network configuration issue. Please consider specifying " + "the network interface for NCCL via environment variables listed in its reference: " + "`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n"; + } + return Fail(ss.str()); +} + +auto GetNCCLType(ArrayInterfaceHandler::Type type) { + auto fatal = [] { + LOG(FATAL) << "Invalid type for NCCL operation."; + return ncclHalf; // dummy return to silent the compiler warning. + }; + using H = ArrayInterfaceHandler; + switch (type) { + case H::kF2: + return ncclHalf; + case H::kF4: + return ncclFloat32; + case H::kF8: + return ncclFloat64; + case H::kF16: + return fatal(); + case H::kI1: + return ncclInt8; + case H::kI2: + return fatal(); + case H::kI4: + return ncclInt32; + case H::kI8: + return ncclInt64; + case H::kU1: + return ncclUint8; + case H::kU2: + return fatal(); + case H::kU4: + return ncclUint32; + case H::kU8: + return ncclUint64; + } + return fatal(); +} + +bool IsBitwiseOp(Op const& op) { + return op == Op::kBitwiseAND || op == Op::kBitwiseOR || op == Op::kBitwiseXOR; +} + +template +void RunBitwiseAllreduce(dh::CUDAStreamView stream, common::Span out_buffer, + std::int8_t const* device_buffer, Func func, std::int32_t world_size, + std::size_t size) { + dh::LaunchN(size, stream, [=] __device__(std::size_t idx) { + auto result = device_buffer[idx]; + for (auto rank = 1; rank < world_size; rank++) { + result = func(result, device_buffer[rank * size + idx]); + } + out_buffer[idx] = result; + }); +} + +[[nodiscard]] Result BitwiseAllReduce(NCCLComm const* pcomm, ncclComm_t handle, + common::Span data, Op op) { + dh::device_vector buffer(data.size() * pcomm->World()); + auto* device_buffer = buffer.data().get(); + + // First gather data from all the workers. + CHECK(handle); + auto rc = GetNCCLResult( + ncclAllGather(data.data(), device_buffer, data.size(), ncclInt8, handle, pcomm->Stream())); + if (!rc.OK()) { + return rc; + } + + // Then reduce locally. + switch (op) { + case Op::kBitwiseAND: + RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_and(), + pcomm->World(), data.size()); + break; + case Op::kBitwiseOR: + RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_or(), + pcomm->World(), data.size()); + break; + case Op::kBitwiseXOR: + RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_xor(), + pcomm->World(), data.size()); + break; + default: + LOG(FATAL) << "Not a bitwise reduce operation."; + } + return Success(); +} + +ncclRedOp_t GetNCCLRedOp(Op const& op) { + ncclRedOp_t result{ncclMax}; + switch (op) { + case Op::kMax: + result = ncclMax; + break; + case Op::kMin: + result = ncclMin; + break; + case Op::kSum: + result = ncclSum; + break; + default: + LOG(FATAL) << "Unsupported reduce operation."; + } + return result; +} +} // namespace + +[[nodiscard]] Result NCCLColl::Allreduce(Comm const& comm, common::Span data, + ArrayInterfaceHandler::Type type, Op op) { + if (!comm.IsDistributed()) { + return Success(); + } + auto nccl = dynamic_cast(&comm); + CHECK(nccl); + return Success() << [&] { + if (IsBitwiseOp(op)) { + return BitwiseAllReduce(nccl, nccl->Handle(), data, op); + } else { + return DispatchDType(type, [=](auto t) { + using T = decltype(t); + auto rdata = common::RestoreType(data); + auto rc = ncclAllReduce(data.data(), data.data(), rdata.size(), GetNCCLType(type), + GetNCCLRedOp(op), nccl->Handle(), nccl->Stream()); + return GetNCCLResult(rc); + }); + } + } << [&] { return nccl->Block(); }; +} + +[[nodiscard]] Result NCCLColl::Broadcast(Comm const& comm, common::Span data, + std::int32_t root) { + if (!comm.IsDistributed()) { + return Success(); + } + auto nccl = dynamic_cast(&comm); + CHECK(nccl); + return Success() << [&] { + return GetNCCLResult(ncclBroadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root, + nccl->Handle(), nccl->Stream())); + } << [&] { return nccl->Block(); }; +} + +[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span data, + std::int64_t size) { + if (!comm.IsDistributed()) { + return Success(); + } + auto nccl = dynamic_cast(&comm); + CHECK(nccl); + auto send = data.subspan(comm.Rank() * size, size); + return Success() << [&] { + return GetNCCLResult( + ncclAllGather(send.data(), data.data(), size, ncclInt8, nccl->Handle(), nccl->Stream())); + } << [&] { return nccl->Block(); }; +} + +namespace cuda_impl { +/** + * @brief Implement allgather-v using broadcast. + * + * https://arxiv.org/abs/1812.05964 + */ +Result BroadcastAllgatherV(NCCLComm const* comm, common::Span data, + common::Span sizes, common::Span recv) { + return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] { + std::size_t offset = 0; + for (std::int32_t r = 0; r < comm->World(); ++r) { + auto as_bytes = sizes[r]; + auto rc = ncclBroadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes, + ncclInt8, r, comm->Handle(), dh::DefaultStream()); + if (rc != ncclSuccess) { + return GetNCCLResult(rc); + } + offset += as_bytes; + } + return Success(); + } << [] { return GetNCCLResult(ncclGroupEnd()); }; +} +} // namespace cuda_impl + +[[nodiscard]] Result NCCLColl::AllgatherV(Comm const& comm, common::Span data, + common::Span sizes, + common::Span recv_segments, + common::Span recv, AllgatherVAlgo algo) { + auto nccl = dynamic_cast(&comm); + CHECK(nccl); + if (!comm.IsDistributed()) { + return Success(); + } + + switch (algo) { + case AllgatherVAlgo::kRing: { + return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] { + // get worker offset + detail::AllgatherVOffset(sizes, recv_segments); + // copy data + auto current = recv.subspan(recv_segments[comm.Rank()], data.size_bytes()); + if (current.data() != data.data()) { + dh::safe_cuda(cudaMemcpyAsync(current.data(), data.data(), current.size_bytes(), + cudaMemcpyDeviceToDevice, nccl->Stream())); + } + return detail::RingAllgatherV(comm, sizes, recv_segments, recv); + } << [] { + return GetNCCLResult(ncclGroupEnd()); + } << [&] { return nccl->Block(); }; + } + case AllgatherVAlgo::kBcast: { + return cuda_impl::BroadcastAllgatherV(nccl, data, sizes, recv); + } + default: { + return Fail("Unknown algorithm for allgather-v"); + } + } +} +} // namespace xgboost::collective + +#endif // defined(XGBOOST_USE_NCCL) diff --git a/src/collective/coll.cuh b/src/collective/coll.cuh new file mode 100644 index 000000000..87fb46711 --- /dev/null +++ b/src/collective/coll.cuh @@ -0,0 +1,29 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once + +#include // for int8_t, int64_t + +#include "../data/array_interface.h" // for ArrayInterfaceHandler +#include "coll.h" // for Coll +#include "comm.h" // for Comm +#include "xgboost/span.h" // for Span + +namespace xgboost::collective { +class NCCLColl : public Coll { + public: + ~NCCLColl() override; + + [[nodiscard]] Result Allreduce(Comm const& comm, common::Span data, + ArrayInterfaceHandler::Type type, Op op) override; + [[nodiscard]] Result Broadcast(Comm const& comm, common::Span data, + std::int32_t root) override; + [[nodiscard]] Result Allgather(Comm const& comm, common::Span data, + std::int64_t size) override; + [[nodiscard]] Result AllgatherV(Comm const& comm, common::Span data, + common::Span sizes, + common::Span recv_segments, + common::Span recv, AllgatherVAlgo algo) override; +}; +} // namespace xgboost::collective diff --git a/src/collective/coll.h b/src/collective/coll.h index 9a318db8d..0189ffd5e 100644 --- a/src/collective/coll.h +++ b/src/collective/coll.h @@ -2,17 +2,20 @@ * Copyright 2023, XGBoost Contributors */ #pragma once -#include // for size_t #include // for int8_t, int64_t #include // for enable_shared_from_this #include "../data/array_interface.h" // for ArrayInterfaceHandler #include "comm.h" // for Comm #include "xgboost/collective/result.h" // for Result -#include "xgboost/context.h" // for Context #include "xgboost/span.h" // for Span namespace xgboost::collective { +enum class AllgatherVAlgo { + kRing = 0, // use ring-based allgather-v + kBcast = 1, // use broadcast-based allgather-v +}; + /** * @brief Interface and base implementation for collective. */ @@ -21,6 +24,8 @@ class Coll : public std::enable_shared_from_this { Coll() = default; virtual ~Coll() noexcept(false) {} // NOLINT + Coll* MakeCUDAVar(); + /** * @brief Allreduce * @@ -29,8 +34,7 @@ class Coll : public std::enable_shared_from_this { * @param [in] op Reduce operation. For custom operation, user needs to reach down to * the CPU implementation. */ - [[nodiscard]] virtual Result Allreduce(Context const* ctx, Comm const& comm, - common::Span data, + [[nodiscard]] virtual Result Allreduce(Comm const& comm, common::Span data, ArrayInterfaceHandler::Type type, Op op); /** * @brief Broadcast @@ -38,29 +42,29 @@ class Coll : public std::enable_shared_from_this { * @param [in,out] data Data buffer for input and output. * @param [in] root Root rank for broadcast. */ - [[nodiscard]] virtual Result Broadcast(Context const* ctx, Comm const& comm, - common::Span data, std::int32_t root); + [[nodiscard]] virtual Result Broadcast(Comm const& comm, common::Span data, + std::int32_t root); /** * @brief Allgather * * @param [in,out] data Data buffer for input and output. * @param [in] size Size of data for each worker. */ - [[nodiscard]] virtual Result Allgather(Context const* ctx, Comm const& comm, - common::Span data, std::size_t size); + [[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span data, + std::int64_t size); /** * @brief Allgather with variable length. * * @param [in] data Input data for the current worker. * @param [in] sizes Size of the input from each worker. - * @param [out] recv_segments pre-allocated offset for each worker in the output, size - * should be equal to (world + 1). + * @param [out] recv_segments pre-allocated offset buffer for each worker in the output, + * size should be equal to (world + 1). GPU ring-based implementation + * doesn't use the buffer. * @param [out] recv pre-allocated buffer for output. */ - [[nodiscard]] virtual Result AllgatherV(Context const* ctx, Comm const& comm, - common::Span data, + [[nodiscard]] virtual Result AllgatherV(Comm const& comm, common::Span data, common::Span sizes, common::Span recv_segments, - common::Span recv); + common::Span recv, AllgatherVAlgo algo); }; } // namespace xgboost::collective diff --git a/src/collective/comm.cc b/src/collective/comm.cc index 3c49303fa..dbd45cbb2 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -262,7 +262,7 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se } RabitComm::~RabitComm() noexcept(false) { - if (!IsDistributed()) { + if (!this->IsDistributed()) { return; } auto rc = this->Shutdown(); diff --git a/src/collective/comm.cu b/src/collective/comm.cu new file mode 100644 index 000000000..31a06e124 --- /dev/null +++ b/src/collective/comm.cu @@ -0,0 +1,112 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#if defined(XGBOOST_USE_NCCL) +#include // for sort +#include // for size_t +#include // for uint64_t, int8_t +#include // for memcpy +#include // for shared_ptr +#include // for stringstream +#include // for vector + +#include "../common/device_helpers.cuh" // for DefaultStream +#include "../common/type.h" // for EraseType +#include "broadcast.h" // for Broadcast +#include "comm.cuh" // for NCCLComm +#include "comm.h" // for Comm +#include "xgboost/collective/result.h" // for Result +#include "xgboost/span.h" // for Span + +namespace xgboost::collective { +namespace { +Result GetUniqueId(Comm const& comm, ncclUniqueId* pid) { + static const int kRootRank = 0; + ncclUniqueId id; + if (comm.Rank() == kRootRank) { + dh::safe_nccl(ncclGetUniqueId(&id)); + } + auto rc = Broadcast(comm, common::Span{reinterpret_cast(&id), sizeof(ncclUniqueId)}, + kRootRank); + if (!rc.OK()) { + return rc; + } + *pid = id; + return Success(); +} + +inline constexpr std::size_t kUuidLength = + sizeof(std::declval().uuid) / sizeof(std::uint64_t); + +void GetCudaUUID(xgboost::common::Span const& uuid, DeviceOrd device) { + cudaDeviceProp prob{}; + dh::safe_cuda(cudaGetDeviceProperties(&prob, device.ordinal)); + std::memcpy(uuid.data(), static_cast(&(prob.uuid)), sizeof(prob.uuid)); +} + +static std::string PrintUUID(xgboost::common::Span const& uuid) { + std::stringstream ss; + for (auto v : uuid) { + ss << std::hex << v; + } + return ss.str(); +} +} // namespace + +Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr pimpl) { + return new NCCLComm{ctx, *this, pimpl}; +} + +NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr pimpl) + : Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(), + root.TaskID()}, + stream_{dh::DefaultStream()} { + this->world_ = root.World(); + this->rank_ = root.Rank(); + this->domain_ = root.Domain(); + if (!root.IsDistributed()) { + return; + } + + dh::safe_cuda(cudaSetDevice(ctx->Ordinal())); + + std::vector uuids(root.World() * kUuidLength, 0); + auto s_uuid = xgboost::common::Span{uuids.data(), uuids.size()}; + auto s_this_uuid = s_uuid.subspan(root.Rank() * kUuidLength, kUuidLength); + GetCudaUUID(s_this_uuid, ctx->Device()); + + auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes()); + CHECK(rc.OK()) << rc.Report(); + + std::vector> converted(root.World()); + std::size_t j = 0; + for (size_t i = 0; i < uuids.size(); i += kUuidLength) { + converted[j] = s_uuid.subspan(i, kUuidLength); + j++; + } + + std::sort(converted.begin(), converted.end()); + auto iter = std::unique(converted.begin(), converted.end()); + auto n_uniques = std::distance(converted.begin(), iter); + + CHECK_EQ(n_uniques, root.World()) + << "Multiple processes within communication group running on same CUDA " + << "device is not supported. " << PrintUUID(s_this_uuid) << "\n"; + + rc = GetUniqueId(root, &nccl_unique_id_); + CHECK(rc.OK()) << rc.Report(); + dh::safe_nccl(ncclCommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank())); + + for (std::int32_t r = 0; r < root.World(); ++r) { + this->channels_.emplace_back( + std::make_shared(root, r, nccl_comm_, dh::DefaultStream())); + } +} + +NCCLComm::~NCCLComm() { + if (nccl_comm_) { + dh::safe_nccl(ncclCommDestroy(nccl_comm_)); + } +} +} // namespace xgboost::collective +#endif // defined(XGBOOST_USE_NCCL) diff --git a/src/collective/comm.cuh b/src/collective/comm.cuh new file mode 100644 index 000000000..ea15c50f3 --- /dev/null +++ b/src/collective/comm.cuh @@ -0,0 +1,67 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once + +#ifdef XGBOOST_USE_NCCL +#include "nccl.h" +#endif // XGBOOST_USE_NCCL +#include "../common/device_helpers.cuh" +#include "coll.h" +#include "comm.h" +#include "xgboost/context.h" + +namespace xgboost::collective { + +inline Result GetCUDAResult(cudaError rc) { + if (rc == cudaSuccess) { + return Success(); + } + std::string msg = thrust::system_error(rc, thrust::cuda_category()).what(); + return Fail(msg); +} + +class NCCLComm : public Comm { + ncclComm_t nccl_comm_{nullptr}; + ncclUniqueId nccl_unique_id_{}; + dh::CUDAStreamView stream_; + + public: + [[nodiscard]] ncclComm_t Handle() const { return nccl_comm_; } + + explicit NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr pimpl); + [[nodiscard]] Result LogTracker(std::string) const override { + LOG(FATAL) << "Device comm is used for logging."; + return Fail("Undefined."); + } + ~NCCLComm() override; + [[nodiscard]] bool IsFederated() const override { return false; } + [[nodiscard]] dh::CUDAStreamView Stream() const { return stream_; } + [[nodiscard]] Result Block() const override { + auto rc = this->Stream().Sync(false); + return GetCUDAResult(rc); + } +}; + +class NCCLChannel : public Channel { + std::int32_t rank_{-1}; + ncclComm_t nccl_comm_{}; + dh::CUDAStreamView stream_; + + public: + explicit NCCLChannel(Comm const& comm, std::int32_t rank, ncclComm_t nccl_comm, + dh::CUDAStreamView stream) + : rank_{rank}, nccl_comm_{nccl_comm}, Channel{comm, nullptr}, stream_{stream} {} + + void SendAll(std::int8_t const* ptr, std::size_t n) override { + dh::safe_nccl(ncclSend(ptr, n, ncclInt8, rank_, nccl_comm_, stream_)); + } + void RecvAll(std::int8_t* ptr, std::size_t n) override { + dh::safe_nccl(ncclRecv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_)); + } + [[nodiscard]] Result Block() override { + auto rc = stream_.Sync(false); + return GetCUDAResult(rc); + } +}; +} // namespace xgboost::collective diff --git a/src/collective/comm.h b/src/collective/comm.h index adf23b9e4..afb543c46 100644 --- a/src/collective/comm.h +++ b/src/collective/comm.h @@ -2,20 +2,20 @@ * Copyright 2023, XGBoost Contributors */ #pragma once -#include // for seconds -#include // for size_t -#include // for int32_t -#include // for shared_ptr -#include // for string -#include // for thread -#include // for remove_const_t -#include // for move -#include // for vector +#include // for seconds +#include // for size_t +#include // for int32_t +#include // for shared_ptr +#include // for string +#include // for thread +#include // for move +#include // for vector #include "loop.h" // for Loop #include "protocol.h" // for PeerInfo #include "xgboost/collective/result.h" // for Result #include "xgboost/collective/socket.h" // for TCPSocket +#include "xgboost/context.h" // for Context #include "xgboost/span.h" // for Span namespace xgboost::collective { @@ -35,13 +35,14 @@ inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) { } class Channel; +class Coll; /** * @brief Base communicator storing info about the tracker and other communicators. */ class Comm { protected: - std::int32_t world_{1}; + std::int32_t world_{-1}; std::int32_t rank_{0}; std::chrono::seconds timeout_{DefaultTimeoutSec()}; std::int32_t retry_{DefaultRetry()}; @@ -69,12 +70,14 @@ class Comm { [[nodiscard]] Result ConnectTracker(TCPSocket* out) const; [[nodiscard]] auto Domain() const { return domain_; } [[nodiscard]] auto Timeout() const { return timeout_; } + [[nodiscard]] auto Retry() const { return retry_; } + [[nodiscard]] auto TaskID() const { return task_id_; } [[nodiscard]] auto Rank() const { return rank_; } - [[nodiscard]] auto World() const { return world_; } - [[nodiscard]] bool IsDistributed() const { return World() > 1; } + [[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; } + [[nodiscard]] bool IsDistributed() const { return world_ != -1; } void Submit(Loop::Op op) const { loop_->Submit(op); } - [[nodiscard]] Result Block() const { return loop_->Block(); } + [[nodiscard]] virtual Result Block() const { return loop_->Block(); } [[nodiscard]] virtual std::shared_ptr Chan(std::int32_t rank) const { return channels_.at(rank); @@ -83,6 +86,8 @@ class Comm { [[nodiscard]] virtual Result LogTracker(std::string msg) const = 0; [[nodiscard]] virtual Result SignalError(Result const&) { return Success(); } + + Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr pimpl); }; class RabitComm : public Comm { @@ -116,7 +121,7 @@ class Channel { explicit Channel(Comm const& comm, std::shared_ptr sock) : sock_{std::move(sock)}, comm_{comm} {} - void SendAll(std::int8_t const* ptr, std::size_t n) { + virtual void SendAll(std::int8_t const* ptr, std::size_t n) { Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast(ptr), n, sock_.get(), 0}; CHECK(sock_.get()); comm_.Submit(std::move(op)); @@ -125,7 +130,7 @@ class Channel { this->SendAll(data.data(), data.size_bytes()); } - void RecvAll(std::int8_t* ptr, std::size_t n) { + virtual void RecvAll(std::int8_t* ptr, std::size_t n) { Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0}; CHECK(sock_.get()); comm_.Submit(std::move(op)); @@ -133,7 +138,7 @@ class Channel { void RecvAll(common::Span data) { this->RecvAll(data.data(), data.size_bytes()); } [[nodiscard]] auto Socket() const { return sock_; } - [[nodiscard]] Result Block() { return comm_.Block(); } + [[nodiscard]] virtual Result Block() { return comm_.Block(); } }; enum class Op { kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 }; diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index dfaac9c35..89b3ad2e6 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -1169,7 +1169,13 @@ class CUDAStreamView { operator cudaStream_t() const { // NOLINT return stream_; } - void Sync() { dh::safe_cuda(cudaStreamSynchronize(stream_)); } + cudaError_t Sync(bool error = true) { + if (error) { + dh::safe_cuda(cudaStreamSynchronize(stream_)); + return cudaSuccess; + } + return cudaStreamSynchronize(stream_); + } }; inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 8fd3120b5..6db201dd5 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -19,7 +19,6 @@ #include "../common/cuda_context.cuh" // CUDAContext #include "../common/device_helpers.cuh" #include "../common/hist_util.h" -#include "../common/io.h" #include "../common/timer.h" #include "../data/ellpack_page.cuh" #include "../data/ellpack_page.h" @@ -39,7 +38,6 @@ #include "xgboost/data.h" #include "xgboost/host_device_vector.h" #include "xgboost/json.h" -#include "xgboost/parameter.h" #include "xgboost/span.h" #include "xgboost/task.h" // for ObjInfo #include "xgboost/tree_model.h" diff --git a/tests/cpp/collective/test_allgather.cc b/tests/cpp/collective/test_allgather.cc index a74b9f149..bdfadc0c7 100644 --- a/tests/cpp/collective/test_allgather.cc +++ b/tests/cpp/collective/test_allgather.cc @@ -14,6 +14,7 @@ #include // for vector #include "../../../src/collective/allgather.h" // for RingAllgather +#include "../../../src/collective/coll.h" // for Coll #include "../../../src/collective/comm.h" // for RabitComm #include "gtest/gtest.h" // for AssertionR... #include "test_worker.h" // for TestDistri... @@ -63,37 +64,79 @@ class Worker : public WorkerForTest { } } - void TestV() { - { - // basic test - std::int32_t n{comm_.Rank()}; - std::vector result; - auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result); - ASSERT_TRUE(rc.OK()) << rc.Report(); - for (std::int32_t i = 0; i < comm_.World(); ++i) { - ASSERT_EQ(result[i], i); - } - } - - { - // V test - std::vector data(comm_.Rank() + 1, comm_.Rank()); - std::vector result; - auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result); - ASSERT_TRUE(rc.OK()) << rc.Report(); - ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2); - std::int32_t k{0}; - for (std::int32_t r = 0; r < comm_.World(); ++r) { - auto seg = common::Span{result.data(), result.size()}.subspan(k, (r + 1)); - if (comm_.Rank() == 0) { - for (auto v : seg) { - ASSERT_EQ(v, r); - } - k += seg.size(); + void CheckV(common::Span result) { + std::int32_t k{0}; + for (std::int32_t r = 0; r < comm_.World(); ++r) { + auto seg = common::Span{result.data(), result.size()}.subspan(k, (r + 1)); + if (comm_.Rank() == 0) { + for (auto v : seg) { + ASSERT_EQ(v, r); } + k += seg.size(); } } } + void TestVRing() { + // V test + std::vector data(comm_.Rank() + 1, comm_.Rank()); + std::vector result; + auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result); + ASSERT_TRUE(rc.OK()) << rc.Report(); + ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2); + CheckV(result); + } + + void TestVBasic() { + // basic test + std::int32_t n{comm_.Rank()}; + std::vector result; + auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result); + ASSERT_TRUE(rc.OK()) << rc.Report(); + for (std::int32_t i = 0; i < comm_.World(); ++i) { + ASSERT_EQ(result[i], i); + } + } + + void TestVAlgo() { + // V test, broadcast + std::vector data(comm_.Rank() + 1, comm_.Rank()); + auto s_data = common::Span{data.data(), data.size()}; + + std::vector sizes(comm_.World(), 0); + sizes[comm_.Rank()] = s_data.size_bytes(); + auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1); + ASSERT_TRUE(rc.OK()) << rc.Report(); + std::shared_ptr pcoll{new Coll{}}; + + std::vector recv_segments(comm_.World() + 1, 0); + std::vector recv(std::accumulate(sizes.cbegin(), sizes.cend(), 0)); + + auto s_recv = common::Span{recv.data(), recv.size()}; + + rc = pcoll->AllgatherV(comm_, common::EraseType(s_data), + common::Span{sizes.data(), sizes.size()}, + common::Span{recv_segments.data(), recv_segments.size()}, + common::EraseType(s_recv), AllgatherVAlgo::kBcast); + ASSERT_TRUE(rc.OK()); + CheckV(s_recv); + + // Test inplace + auto test_inplace = [&] (AllgatherVAlgo algo) { + std::fill_n(s_recv.data(), s_recv.size(), 0); + auto current = s_recv.subspan(recv_segments[comm_.Rank()], + recv_segments[comm_.Rank() + 1] - recv_segments[comm_.Rank()]); + std::copy_n(data.data(), data.size(), current.data()); + rc = pcoll->AllgatherV(comm_, common::EraseType(current), + common::Span{sizes.data(), sizes.size()}, + common::Span{recv_segments.data(), recv_segments.size()}, + common::EraseType(s_recv), algo); + ASSERT_TRUE(rc.OK()); + CheckV(s_recv); + }; + + test_inplace(AllgatherVAlgo::kBcast); + test_inplace(AllgatherVAlgo::kRing); + } }; } // namespace @@ -106,12 +149,30 @@ TEST_F(AllgatherTest, Basic) { }); } -TEST_F(AllgatherTest, V) { +TEST_F(AllgatherTest, VBasic) { std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { Worker worker{host, port, timeout, n_workers, r}; - worker.TestV(); + worker.TestVBasic(); + }); +} + +TEST_F(AllgatherTest, VRing) { + std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + Worker worker{host, port, timeout, n_workers, r}; + worker.TestVRing(); + }); +} + +TEST_F(AllgatherTest, VAlgo) { + std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + Worker worker{host, port, timeout, n_workers, r}; + worker.TestVAlgo(); }); } } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_allgather.cu b/tests/cpp/collective/test_allgather.cu new file mode 100644 index 000000000..48f7c2615 --- /dev/null +++ b/tests/cpp/collective/test_allgather.cu @@ -0,0 +1,117 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#if defined(XGBOOST_USE_NCCL) +#include +#include // for device_vector +#include // for equal +#include // for Span + +#include // for size_t +#include // for int32_t, int64_t +#include // for vector + +#include "../../../src/collective/allgather.h" // for RingAllgather +#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector +#include "../../../src/common/type.h" // for EraseType +#include "test_worker.cuh" // for NCCLWorkerForTest +#include "test_worker.h" // for TestDistributed, WorkerForTest + +namespace xgboost::collective { +namespace { +class Worker : public NCCLWorkerForTest { + public: + using NCCLWorkerForTest::NCCLWorkerForTest; + + void TestV(AllgatherVAlgo algo) { + { + // basic test + std::size_t n = 1; + // create data + dh::device_vector data(n, comm_.Rank()); + auto s_data = common::EraseType(common::Span{data.data().get(), data.size()}); + // get size + std::vector sizes(comm_.World(), -1); + sizes[comm_.Rank()] = s_data.size_bytes(); + auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1); + ASSERT_TRUE(rc.OK()) << rc.Report(); + // create result + dh::device_vector result(comm_.World(), -1); + auto s_result = common::EraseType(dh::ToSpan(result)); + + std::vector recv_seg(nccl_comm_->World() + 1, 0); + rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()}, + common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo); + ASSERT_TRUE(rc.OK()) << rc.Report(); + + for (std::int32_t i = 0; i < comm_.World(); ++i) { + ASSERT_EQ(result[i], i); + } + } + { + // V test + std::size_t n = 256 * 256; + // create data + dh::device_vector data(n * nccl_comm_->Rank(), nccl_comm_->Rank()); + auto s_data = common::EraseType(common::Span{data.data().get(), data.size()}); + // get size + std::vector sizes(nccl_comm_->World(), 0); + sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes(); + auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1); + ASSERT_TRUE(rc.OK()) << rc.Report(); + auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0); + // create result + dh::device_vector result(n_bytes / sizeof(std::int32_t), -1); + auto s_result = common::EraseType(dh::ToSpan(result)); + + std::vector recv_seg(nccl_comm_->World() + 1, 0); + rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()}, + common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo); + ASSERT_TRUE(rc.OK()) << rc.Report(); + // check segment size + if (algo != AllgatherVAlgo::kBcast) { + auto size = recv_seg[nccl_comm_->Rank() + 1] - recv_seg[nccl_comm_->Rank()]; + ASSERT_EQ(size, n * nccl_comm_->Rank() * sizeof(std::int32_t)); + ASSERT_EQ(size, sizes[nccl_comm_->Rank()]); + } + // check data + std::size_t k{0}; + for (std::int32_t r = 0; r < nccl_comm_->World(); ++r) { + std::size_t s = n * r; + auto current = dh::ToSpan(result).subspan(k, s); + std::vector h_data(current.size()); + dh::CopyDeviceSpanToVector(&h_data, current); + for (auto v : h_data) { + ASSERT_EQ(v, r); + } + k += s; + } + } + } +}; + +class AllgatherTestGPU : public SocketTest {}; +} // namespace + +TEST_F(AllgatherTestGPU, MGPUTestVRing) { + auto n_workers = common::AllVisibleGPUs(); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + Worker w{host, port, timeout, n_workers, r}; + w.Setup(); + w.TestV(AllgatherVAlgo::kRing); + w.TestV(AllgatherVAlgo::kBcast); + }); +} + +TEST_F(AllgatherTestGPU, MGPUTestVBcast) { + auto n_workers = common::AllVisibleGPUs(); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + Worker w{host, port, timeout, n_workers, r}; + w.Setup(); + w.TestV(AllgatherVAlgo::kBcast); + }); +} +} // namespace xgboost::collective +#endif // defined(XGBOOST_USE_NCCL) diff --git a/tests/cpp/collective/test_allreduce.cc b/tests/cpp/collective/test_allreduce.cc index 77d23f6fe..744608dec 100644 --- a/tests/cpp/collective/test_allreduce.cc +++ b/tests/cpp/collective/test_allreduce.cc @@ -6,10 +6,10 @@ #include "../../../src/collective/allreduce.h" #include "../../../src/collective/coll.h" // for Coll #include "../../../src/collective/tracker.h" -#include "test_worker.h" // for WorkerForTest, TestDistributed +#include "../../../src/common/type.h" // for EraseType +#include "test_worker.h" // for WorkerForTest, TestDistributed namespace xgboost::collective { - namespace { class AllreduceWorker : public WorkerForTest { public: @@ -50,11 +50,10 @@ class AllreduceWorker : public WorkerForTest { } void BitOr() { - Context ctx; std::vector data(comm_.World(), 0); data[comm_.Rank()] = ~std::uint32_t{0}; auto pcoll = std::shared_ptr{new Coll{}}; - auto rc = pcoll->Allreduce(&ctx, comm_, EraseType(common::Span{data.data(), data.size()}), + auto rc = pcoll->Allreduce(comm_, common::EraseType(common::Span{data.data(), data.size()}), ArrayInterfaceHandler::kU4, Op::kBitwiseOR); ASSERT_TRUE(rc.OK()) << rc.Report(); for (auto v : data) { diff --git a/tests/cpp/collective/test_allreduce.cu b/tests/cpp/collective/test_allreduce.cu new file mode 100644 index 000000000..af9a4e58f --- /dev/null +++ b/tests/cpp/collective/test_allreduce.cu @@ -0,0 +1,70 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#if defined(XGBOOST_USE_NCCL) +#include +#include // for host_vector + +#include "../../../src/collective/coll.h" // for Coll +#include "../../../src/common/common.h" +#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector +#include "../../../src/common/type.h" // for EraseType +#include "../helpers.h" // for MakeCUDACtx +#include "test_worker.cuh" // for NCCLWorkerForTest +#include "test_worker.h" // for WorkerForTest, TestDistributed + +namespace xgboost::collective { +namespace { +class AllreduceTestGPU : public SocketTest {}; + +class Worker : public NCCLWorkerForTest { + public: + using NCCLWorkerForTest::NCCLWorkerForTest; + + void BitOr() { + dh::device_vector data(comm_.World(), 0); + data[comm_.Rank()] = ~std::uint32_t{0}; + auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)), + ArrayInterfaceHandler::kU4, Op::kBitwiseOR); + ASSERT_TRUE(rc.OK()) << rc.Report(); + thrust::host_vector h_data(data.size()); + thrust::copy(data.cbegin(), data.cend(), h_data.begin()); + for (auto v : h_data) { + ASSERT_EQ(v, ~std::uint32_t{0}); + } + } + + void Acc() { + dh::device_vector data(314, 1.5); + auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)), + ArrayInterfaceHandler::kF8, Op::kSum); + ASSERT_TRUE(rc.OK()) << rc.Report(); + for (std::size_t i = 0; i < data.size(); ++i) { + auto v = data[i]; + ASSERT_EQ(v, 1.5 * static_cast(comm_.World())) << i; + } + } +}; +} // namespace + +TEST_F(AllreduceTestGPU, BitOr) { + auto n_workers = common::AllVisibleGPUs(); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + Worker w{host, port, timeout, n_workers, r}; + w.Setup(); + w.BitOr(); + }); +} + +TEST_F(AllreduceTestGPU, Sum) { + auto n_workers = common::AllVisibleGPUs(); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + Worker w{host, port, timeout, n_workers, r}; + w.Setup(); + w.Acc(); + }); +} +} // namespace xgboost::collective +#endif // defined(XGBOOST_USE_NCCL) diff --git a/tests/cpp/collective/test_broadcast.cc b/tests/cpp/collective/test_broadcast.cc index 0ade86567..4d0d87e93 100644 --- a/tests/cpp/collective/test_broadcast.cc +++ b/tests/cpp/collective/test_broadcast.cc @@ -47,5 +47,5 @@ TEST_F(BroadcastTest, Basic) { Worker worker{host, port, timeout, n_workers, r}; worker.Run(); }); -} +} // namespace } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_worker.cuh b/tests/cpp/collective/test_worker.cuh new file mode 100644 index 000000000..058f8845e --- /dev/null +++ b/tests/cpp/collective/test_worker.cuh @@ -0,0 +1,32 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once +#include // for shared_ptr + +#include "../../../src/collective/coll.h" // for Coll +#include "../../../src/collective/comm.h" // for Comm +#include "test_worker.h" +#include "xgboost/context.h" // for Context + +namespace xgboost::collective { +class NCCLWorkerForTest : public WorkerForTest { + protected: + std::shared_ptr coll_; + std::shared_ptr nccl_comm_; + std::shared_ptr nccl_coll_; + Context ctx_; + + public: + using WorkerForTest::WorkerForTest; + + void Setup() { + ctx_ = MakeCUDACtx(comm_.Rank()); + coll_.reset(new Coll{}); + nccl_comm_.reset(this->comm_.MakeCUDAVar(&ctx_, coll_)); + nccl_coll_.reset(coll_->MakeCUDAVar()); + ASSERT_EQ(comm_.World(), nccl_comm_->World()); + ASSERT_EQ(comm_.Rank(), nccl_comm_->Rank()); + } +}; +} // namespace xgboost::collective diff --git a/tests/cpp/collective/test_worker.h b/tests/cpp/collective/test_worker.h index a3d6de875..6578ff142 100644 --- a/tests/cpp/collective/test_worker.h +++ b/tests/cpp/collective/test_worker.h @@ -1,6 +1,7 @@ /** * Copyright 2023, XGBoost Contributors */ +#pragma once #include #include // for seconds From 2cfc90e8dbb3c20afed5d6d575d2a6062ae66a04 Mon Sep 17 00:00:00 2001 From: omahs <73983677+omahs@users.noreply.github.com> Date: Mon, 30 Oct 2023 09:52:12 +0100 Subject: [PATCH 3/5] Fix typos (#9731) --- doc/faq.rst | 10 +++++----- doc/parameter.rst | 2 +- python-package/xgboost/spark/params.py | 2 +- rabit/src/allreduce_base.cc | 4 ++-- rabit/src/allreduce_base.h | 10 +++++----- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/doc/faq.rst b/doc/faq.rst index 51de4bbc8..4fe63076c 100644 --- a/doc/faq.rst +++ b/doc/faq.rst @@ -10,14 +10,14 @@ How to tune parameters See :doc:`Parameter Tuning Guide `. ************************ -Description on the model +Description of the model ************************ See :doc:`Introduction to Boosted Trees `. ******************** I have a big dataset ******************** -XGBoost is designed to be memory efficient. Usually it can handle problems as long as the data fit into your memory. +XGBoost is designed to be memory efficient. Usually it can handle problems as long as the data fits into your memory. This usually means millions of instances. If you are running out of memory, checkout the tutorial page for using :doc:`distributed training ` with one of the many frameworks, or the :doc:`external memory version ` for using external memory. @@ -26,7 +26,7 @@ If you are running out of memory, checkout the tutorial page for using :doc:`dis ********************************** How to handle categorical feature? ********************************** -Visit :doc:`this tutorial ` for a walk through of categorical data handling and some worked examples. +Visit :doc:`this tutorial ` for a walkthrough of categorical data handling and some worked examples. ****************************************************************** Why not implement distributed XGBoost on top of X (Spark, Hadoop)? @@ -37,14 +37,14 @@ The ultimate question will still come back to how to push the limit of each comp and use less resources to complete the task (thus with less communication and chance of failure). To achieve these, we decide to reuse the optimizations in the single node XGBoost and build the distributed version on top of it. -The demand of communication in machine learning is rather simple, in the sense that we can depend on a limited set of APIs (in our case rabit). +The demand for communication in machine learning is rather simple, in the sense that we can depend on a limited set of APIs (in our case rabit). Such design allows us to reuse most of the code, while being portable to major platforms such as Hadoop/Yarn, MPI, SGE. Most importantly, it pushes the limit of the computation resources we can use. **************************************** How can I port a model to my own system? **************************************** -The model and data format of XGBoost is exchangeable, +The model and data format of XGBoost are exchangeable, which means the model trained by one language can be loaded in another. This means you can train the model using R, while running prediction using Java or C++, which are more common in production systems. diff --git a/doc/parameter.rst b/doc/parameter.rst index 1162d6f1f..88a712a5a 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -73,7 +73,7 @@ Parameters for Tree Booster =========================== * ``eta`` [default=0.3, alias: ``learning_rate``] - - Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features, and ``eta`` shrinks the feature weights to make the boosting process more conservative. + - Step size shrinkage used in update to prevent overfitting. After each boosting step, we can directly get the weights of new features, and ``eta`` shrinks the feature weights to make the boosting process more conservative. - range: [0,1] * ``gamma`` [default=0, alias: ``min_split_loss``] diff --git a/python-package/xgboost/spark/params.py b/python-package/xgboost/spark/params.py index 7c3231431..a81f6cd33 100644 --- a/python-package/xgboost/spark/params.py +++ b/python-package/xgboost/spark/params.py @@ -17,7 +17,7 @@ class HasArbitraryParamsDict(Params): Params._dummy(), "arbitrary_params_dict", "arbitrary_params_dict This parameter holds all of the additional parameters which are " - "not exposed as the the XGBoost Spark estimator params but can be recognized by " + "not exposed as the XGBoost Spark estimator params but can be recognized by " "underlying XGBoost library. It is stored as a dictionary.", ) diff --git a/rabit/src/allreduce_base.cc b/rabit/src/allreduce_base.cc index 416801ee2..5cab4ae32 100644 --- a/rabit/src/allreduce_base.cc +++ b/rabit/src/allreduce_base.cc @@ -106,7 +106,7 @@ bool AllreduceBase::Init(int argc, char* argv[]) { } } if (dmlc_role != "worker") { - LOG(FATAL) << "Rabit Module currently only work with dmlc worker"; + LOG(FATAL) << "Rabit Module currently only works with dmlc worker"; } // clear the setting before start reconnection @@ -273,7 +273,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) { return xgboost::collective::Success(); } /*! - * \brief connect to the tracker to fix the the missing links + * \brief connect to the tracker to fix the missing links * this function is also used when the engine start up */ [[nodiscard]] xgboost::collective::Result AllreduceBase::ReConnectLinks(const char *cmd) { diff --git a/rabit/src/allreduce_base.h b/rabit/src/allreduce_base.h index f40754273..7724bf3d5 100644 --- a/rabit/src/allreduce_base.h +++ b/rabit/src/allreduce_base.h @@ -89,7 +89,7 @@ class AllreduceBase : public IEngine { } /*! - * \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf, + * \brief internal Allgather function, each node has a segment of data in the ring of sendrecvbuf, * the data provided by current node k is [slice_begin, slice_end), * the next node's segment must start with slice_end * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments @@ -281,7 +281,7 @@ class AllreduceBase : public IEngine { * this function can not be used together with ReadToRingBuffer * a link can either read into the ring buffer, or existing array * \param max_size maximum size of array - * \return true if it is an successful read, false if there is some error happens, check errno + * \return true if it is a successful read, false if there is some error happens, check errno */ inline ReturnType ReadToArray(void *recvbuf_, size_t max_size) { if (max_size == size_read) return kSuccess; @@ -299,7 +299,7 @@ class AllreduceBase : public IEngine { * \brief write data in array to sock * \param sendbuf_ head of array * \param max_size maximum size of array - * \return true if it is an successful write, false if there is some error happens, check errno + * \return true if it is a successful write, false if there is some error happens, check errno */ inline ReturnType WriteFromArray(const void *sendbuf_, size_t max_size) { const char *p = static_cast(sendbuf_); @@ -333,7 +333,7 @@ class AllreduceBase : public IEngine { */ [[nodiscard]] xgboost::collective::Result ConnectTracker(xgboost::collective::TCPSocket *out) const; /*! - * \brief connect to the tracker to fix the the missing links + * \brief connect to the tracker to fix the missing links * this function is also used when the engine start up * \param cmd possible command to sent to tracker */ @@ -358,7 +358,7 @@ class AllreduceBase : public IEngine { size_t count, ReduceFunction reducer); /*! - * \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure + * \brief broadcast data from root to all nodes, this function can fail, and will return the cause of failure * \param sendrecvbuf_ buffer for both sending and receiving data * \param size the size of the data to be broadcasted * \param root the root worker id to broadcast the data From fa65cf664664b7348f2cf1407ddea1d1adb6d829 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Tue, 31 Oct 2023 01:28:34 +0800 Subject: [PATCH 4/5] [doc] How to configure regarding to stage-level (#9727) --------- Co-authored-by: Jiaming Yuan --- doc/tutorials/spark_estimator.rst | 51 +++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 13 deletions(-) diff --git a/doc/tutorials/spark_estimator.rst b/doc/tutorials/spark_estimator.rst index 44bdd7733..8bd1dcd97 100644 --- a/doc/tutorials/spark_estimator.rst +++ b/doc/tutorials/spark_estimator.rst @@ -87,8 +87,8 @@ XGBoost PySpark GPU support XGBoost PySpark fully supports GPU acceleration. Users are not only able to enable efficient training but also utilize their GPUs for the whole PySpark pipeline including ETL and inference. In below sections, we will walk through an example of training on a -PySpark standalone GPU cluster. To get started, first we need to install some additional -packages, then we can set the ``device`` parameter to ``cuda`` or ``gpu``. +Spark standalone cluster with GPU support. To get started, first we need to install some +additional packages, then we can set the ``device`` parameter to ``cuda`` or ``gpu``. Prepare the necessary packages ============================== @@ -128,7 +128,8 @@ Write your PySpark application ============================== Below snippet is a small example for training xgboost model with PySpark. Notice that we are -using a list of feature names and the additional parameter ``device``: +using a list of feature names instead of vector type as the input. The parameter ``"device=cuda"`` +specifically indicates that the training will be performed on a GPU. .. code-block:: python @@ -163,14 +164,29 @@ using a list of feature names and the additional parameter ``device``: predict_df = model.transform(test_df) predict_df.show() -Like other distributed interfaces, the ```device`` parameter doesn't support specifying ordinal as GPUs are managed by Spark instead of XGBoost (good: ``device=cuda``, bad: ``device=cuda:0``). +Like other distributed interfaces, the ``device`` parameter doesn't support specifying ordinal as GPUs are managed by Spark instead of XGBoost (good: ``device=cuda``, bad: ``device=cuda:0``). + +.. _stage-level-scheduling: Submit the PySpark application ============================== -Assuming you have configured your Spark cluster with GPU support. Otherwise, please +Assuming you have configured the Spark standalone cluster with GPU support. Otherwise, please refer to `spark standalone configuration with GPU support `_. +Starting from XGBoost 2.0.1, stage-level scheduling is automatically enabled. Therefore, +if you are using Spark standalone cluster version 3.4.0 or higher, we strongly recommend +configuring the ``"spark.task.resource.gpu.amount"`` as a fractional value. This will +enable running multiple tasks in parallel during the ETL phase. An example configuration +would be ``"spark.task.resource.gpu.amount=1/spark.executor.cores"``. However, if you are +using a XGBoost version earlier than 2.0.1 or a Spark standalone cluster version below 3.4.0, +you still need to set ``"spark.task.resource.gpu.amount"`` equal to ``"spark.executor.resource.gpu.amount"``. + +.. note:: + + As of now, the stage-level scheduling feature in XGBoost is limited to the Spark standalone cluster mode. + However, we have plans to expand its compatibility to YARN and Kubernetes once Spark 3.5.1 is officially released. + .. code-block:: bash export PYSPARK_DRIVER_PYTHON=python @@ -178,19 +194,21 @@ refer to `spark standalone configuration with GPU support :7077 \ + --conf spark.executor.cores=12 \ + --conf spark.task.cpus=1 \ --conf spark.executor.resource.gpu.amount=1 \ - --conf spark.task.resource.gpu.amount=1 \ + --conf spark.task.resource.gpu.amount=0.08 \ --archives xgboost_env.tar.gz#environment \ xgboost_app.py - -The submit command sends the Python environment created by pip or conda along with the -specification of GPU allocation. We will revisit this command later on. +The above command submits the xgboost pyspark application with the python environment created by pip or conda, +specifying a request for 1 GPU and 12 CPUs per executor. So you can see, a total of 12 tasks per executor will be +executed concurrently during the ETL phase. Model Persistence ================= -Similar to standard PySpark ml estimators, one can persist and reuse the model with ``save` +Similar to standard PySpark ml estimators, one can persist and reuse the model with ``save`` and ``load`` methods: .. code-block:: python @@ -230,8 +248,13 @@ Accelerate the whole pipeline for xgboost pyspark With `RAPIDS Accelerator for Apache Spark `_, you can leverage GPUs to accelerate the whole pipeline (ETL, Train, Transform) for xgboost -pyspark without any Python code change. An example submit command is shown below with -additional spark configurations and dependencies: +pyspark without the need for any code modifications. Likewise, you have the option to configure +the ``"spark.task.resource.gpu.amount"`` setting as a fractional value, enabling a higher +number of tasks to be executed in parallel during the ETL phase. please refer to +:ref:`stage-level-scheduling` for more details. + + +An example submit command is shown below with additional spark configurations and dependencies: .. code-block:: bash @@ -240,8 +263,10 @@ additional spark configurations and dependencies: spark-submit \ --master spark://:7077 \ + --conf spark.executor.cores=12 \ + --conf spark.task.cpus=1 \ --conf spark.executor.resource.gpu.amount=1 \ - --conf spark.task.resource.gpu.amount=1 \ + --conf spark.task.resource.gpu.amount=0.08 \ --packages com.nvidia:rapids-4-spark_2.12:23.04.0 \ --conf spark.plugins=com.nvidia.spark.SQLPlugin \ --conf spark.sql.execution.arrow.maxRecordsPerBatch=1000000 \ From 80390e6cb69c4f478217179bede451f7f8dd0b56 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 31 Oct 2023 02:39:55 +0800 Subject: [PATCH 5/5] [coll] Federated comm. (#9732) --- plugin/federated/CMakeLists.txt | 2 +- plugin/federated/federated_comm.cc | 114 ++++++++++++++++++ plugin/federated/federated_comm.h | 53 ++++++++ plugin/federated/federated_server.cc | 7 +- plugin/federated/federated_server.h | 20 +-- plugin/federated/federated_tracker.cc | 101 ++++++++++++++++ plugin/federated/federated_tracker.h | 41 +++++++ src/collective/tracker.h | 2 +- tests/cpp/common/test_bitfield.cc | 25 ++++ tests/cpp/helpers.h | 27 +++++ .../plugin/federated/test_federated_comm.cc | 84 +++++++++++++ tests/cpp/plugin/federated/test_worker.h | 42 +++++++ tests/cpp/plugin/helpers.h | 6 +- 13 files changed, 508 insertions(+), 16 deletions(-) create mode 100644 plugin/federated/federated_comm.cc create mode 100644 plugin/federated/federated_comm.h create mode 100644 plugin/federated/federated_tracker.cc create mode 100644 plugin/federated/federated_tracker.h create mode 100644 tests/cpp/plugin/federated/test_federated_comm.cc create mode 100644 tests/cpp/plugin/federated/test_worker.h diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt index be854d755..7c2cfa6fb 100644 --- a/plugin/federated/CMakeLists.txt +++ b/plugin/federated/CMakeLists.txt @@ -28,6 +28,6 @@ target_sources(federated_client INTERFACE federated_client.h) target_link_libraries(federated_client INTERFACE federated_proto) # Rabit engine for Federated Learning. -target_sources(objxgboost PRIVATE federated_server.cc) +target_sources(objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc) target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL") target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1) diff --git a/plugin/federated/federated_comm.cc b/plugin/federated/federated_comm.cc new file mode 100644 index 000000000..4b51fd52d --- /dev/null +++ b/plugin/federated/federated_comm.cc @@ -0,0 +1,114 @@ +/** + * Copyright 2023, XGBoost contributors + */ +#include "federated_comm.h" + +#include + +#include // for int32_t +#include // for getenv +#include // for string, stoi + +#include "../../src/common/common.h" // for Split +#include "../../src/common/json_utils.h" // for OptionalArg +#include "xgboost/json.h" // for Json +#include "xgboost/logging.h" + +namespace xgboost::collective { +void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_t world, + std::int32_t rank, std::string const& server_cert, + std::string const& client_key, std::string const& client_cert) { + this->rank_ = rank; + this->world_ = world; + + this->tracker_.host = host; + this->tracker_.port = port; + this->tracker_.rank = rank; + + CHECK_GE(world, 1) << "Invalid world size."; + CHECK_GE(rank, 0) << "Invalid worker rank."; + CHECK_LT(rank, world) << "Invalid worker rank."; + + if (server_cert.empty()) { + stub_ = [&] { + grpc::ChannelArguments args; + args.SetMaxReceiveMessageSize(std::numeric_limits::max()); + return federated::Federated::NewStub( + grpc::CreateCustomChannel(host, grpc::InsecureChannelCredentials(), args)); + }(); + } else { + stub_ = [&] { + grpc::SslCredentialsOptions options; + options.pem_root_certs = server_cert; + options.pem_private_key = client_key; + options.pem_cert_chain = client_cert; + grpc::ChannelArguments args; + args.SetMaxReceiveMessageSize(std::numeric_limits::max()); + auto channel = grpc::CreateCustomChannel(host, grpc::SslCredentials(options), args); + channel->WaitForConnected( + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN))); + return federated::Federated::NewStub(channel); + }(); + } +} + +FederatedComm::FederatedComm(Json const& config) { + /** + * Topology + */ + std::string server_address{}; + std::int32_t world_size{0}; + std::int32_t rank{-1}; + // Parse environment variables first. + auto* value = std::getenv("FEDERATED_SERVER_ADDRESS"); + if (value != nullptr) { + server_address = value; + } + value = std::getenv("FEDERATED_WORLD_SIZE"); + if (value != nullptr) { + world_size = std::stoi(value); + } + value = std::getenv("FEDERATED_RANK"); + if (value != nullptr) { + rank = std::stoi(value); + } + + server_address = OptionalArg(config, "federated_server_address", server_address); + world_size = + OptionalArg(config, "federated_world_size", static_cast(world_size)); + rank = OptionalArg(config, "federated_rank", static_cast(rank)); + + auto parsed = common::Split(server_address, ':'); + CHECK_EQ(parsed.size(), 2) << "invalid server address:" << server_address; + + CHECK_NE(rank, -1) << "Parameter `federated_rank` is required"; + CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required."; + CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required."; + + /** + * Certificates + */ + std::string server_cert{}; + std::string client_key{}; + std::string client_cert{}; + value = getenv("FEDERATED_SERVER_CERT_PATH"); + if (value != nullptr) { + server_cert = value; + } + value = getenv("FEDERATED_CLIENT_KEY_PATH"); + if (value != nullptr) { + client_key = value; + } + value = getenv("FEDERATED_CLIENT_CERT_PATH"); + if (value != nullptr) { + client_cert = value; + } + + server_cert = OptionalArg(config, "federated_server_cert_path", server_cert); + client_key = OptionalArg(config, "federated_client_key_path", client_key); + client_cert = OptionalArg(config, "federated_client_cert_path", client_cert); + + this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key, + client_cert); +} +} // namespace xgboost::collective diff --git a/plugin/federated/federated_comm.h b/plugin/federated/federated_comm.h new file mode 100644 index 000000000..8e6fe7d67 --- /dev/null +++ b/plugin/federated/federated_comm.h @@ -0,0 +1,53 @@ +/** + * Copyright 2023, XGBoost contributors + */ +#pragma once + +#include +#include + +#include // for int32_t +#include // for unique_ptr +#include // for string + +#include "../../src/collective/comm.h" // for Comm +#include "../../src/common/json_utils.h" // for OptionalArg +#include "xgboost/json.h" + +namespace xgboost::collective { +class FederatedComm : public Comm { + std::unique_ptr stub_; + + void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank, + std::string const& server_cert, std::string const& client_key, + std::string const& client_cert); + + public: + /** + * @param config + * + * - federated_server_address: Tracker address + * - federated_world_size: The number of workers + * - federated_rank: Rank of federated worker + * - federated_server_cert_path + * - federated_client_key_path + * - federated_client_cert_path + */ + explicit FederatedComm(Json const& config); + explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world, + std::int32_t rank) { + this->Init(host, port, world, rank, {}, {}, {}); + } + ~FederatedComm() override { stub_.reset(); } + + [[nodiscard]] std::shared_ptr Chan(std::int32_t) const override { + LOG(FATAL) << "peer to peer communication is not allowed for federated learning."; + return nullptr; + } + [[nodiscard]] Result LogTracker(std::string msg) const override { + LOG(CONSOLE) << msg; + return Success(); + } + [[nodiscard]] bool IsFederated() const override { return true; } +}; +} // namespace xgboost::collective diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index ad6cf6022..9dd97c2e1 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -4,12 +4,15 @@ #include "federated_server.h" #include +#include // for Server #include #include #include +#include "../../src/collective/comm.h" #include "../../src/common/io.h" +#include "../../src/common/json_utils.h" namespace xgboost::federated { grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request, @@ -46,7 +49,7 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest void RunServer(int port, std::size_t world_size, char const* server_key_file, char const* server_cert_file, char const* client_cert_file) { std::string const server_address = "0.0.0.0:" + std::to_string(port); - FederatedService service{world_size}; + FederatedService service{static_cast(world_size)}; grpc::ServerBuilder builder; auto options = @@ -68,7 +71,7 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file, void RunInsecureServer(int port, std::size_t world_size) { std::string const server_address = "0.0.0.0:" + std::to_string(port); - FederatedService service{world_size}; + FederatedService service{static_cast(world_size)}; grpc::ServerBuilder builder; builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); diff --git a/plugin/federated/federated_server.h b/plugin/federated/federated_server.h index 711ef5588..20f3149f9 100644 --- a/plugin/federated/federated_server.h +++ b/plugin/federated/federated_server.h @@ -1,18 +1,22 @@ -/*! - * Copyright 2022 XGBoost contributors +/** + * Copyright 2022-2023, XGBoost contributors */ #pragma once #include +#include // for int32_t +#include // for future + #include "../../src/collective/in_memory_handler.h" +#include "../../src/collective/tracker.h" // for Tracker +#include "xgboost/collective/result.h" // for Result -namespace xgboost { -namespace federated { - +namespace xgboost::federated { class FederatedService final : public Federated::Service { public: - explicit FederatedService(std::size_t const world_size) : handler_{world_size} {} + explicit FederatedService(std::int32_t world_size) + : handler_{static_cast(world_size)} {} grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, AllgatherReply* reply) override; @@ -34,6 +38,4 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file, char const* server_cert_file, char const* client_cert_file); void RunInsecureServer(int port, std::size_t world_size); - -} // namespace federated -} // namespace xgboost +} // namespace xgboost::federated diff --git a/plugin/federated/federated_tracker.cc b/plugin/federated/federated_tracker.cc new file mode 100644 index 000000000..3dad9d7ce --- /dev/null +++ b/plugin/federated/federated_tracker.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2022-2023, XGBoost contributors + */ +#include "federated_tracker.h" + +#include // for InsecureServerCredentials, ... +#include // for ServerBuilder + +#include // for ms +#include // for int32_t +#include // for exception +#include // for numeric_limits +#include // for string +#include // for sleep_for + +#include "../../src/common/io.h" // for ReadAll +#include "../../src/common/json_utils.h" // for RequiredArg +#include "../../src/common/timer.h" // for Timer +#include "federated_server.h" // for FederatedService + +namespace xgboost::collective { +FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} { + auto is_secure = RequiredArg(config, "federated_secure", __func__); + if (is_secure) { + server_key_path_ = RequiredArg(config, "server_key_path", __func__); + server_cert_file_ = RequiredArg(config, "server_cert_path", __func__); + client_cert_file_ = RequiredArg(config, "client_cert_path", __func__); + } +} + +std::future FederatedTracker::Run() { + return std::async([this]() { + std::string const server_address = "0.0.0.0:" + std::to_string(this->port_); + federated::FederatedService service{static_cast(this->n_workers_)}; + grpc::ServerBuilder builder; + + if (this->server_cert_file_.empty()) { + builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); + if (this->port_ == 0) { + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_); + } else { + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + } + builder.RegisterService(&service); + server_ = builder.BuildAndStart(); + LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size " + << this->n_workers_; + } else { + auto options = grpc::SslServerCredentialsOptions( + GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); + options.pem_root_certs = xgboost::common::ReadAll(client_cert_file_); + auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair(); + key.private_key = xgboost::common::ReadAll(server_key_path_); + key.cert_chain = xgboost::common::ReadAll(server_cert_file_); + options.pem_key_cert_pairs.push_back(key); + builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); + if (this->port_ == 0) { + builder.AddListeningPort(server_address, grpc::SslServerCredentials(options), &port_); + } else { + builder.AddListeningPort(server_address, grpc::SslServerCredentials(options)); + } + builder.RegisterService(&service); + server_ = builder.BuildAndStart(); + LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size " + << n_workers_; + } + + try { + server_->Wait(); + } catch (std::exception const& e) { + return collective::Fail(std::string{e.what()}); + } + return collective::Success(); + }); +} + +FederatedTracker::~FederatedTracker() = default; + +Result FederatedTracker::Shutdown() { + common::Timer timer; + timer.Start(); + using namespace std::chrono_literals; + while (!server_) { + timer.Stop(); + auto ela = timer.ElapsedSeconds(); + if (ela > this->Timeout().count()) { + return Fail("Failed to shutdown, timeout:" + std::to_string(this->Timeout().count()) + + " seconds."); + } + std::this_thread::sleep_for(10ms); + } + + try { + server_->Shutdown(); + } catch (std::exception const& e) { + return Fail("Failed to shutdown:" + std::string{e.what()}); + } + + return Success(); +} +} // namespace xgboost::collective diff --git a/plugin/federated/federated_tracker.h b/plugin/federated/federated_tracker.h new file mode 100644 index 000000000..9043adb38 --- /dev/null +++ b/plugin/federated/federated_tracker.h @@ -0,0 +1,41 @@ +/** + * Copyright 2022-2023, XGBoost contributors + */ +#pragma once +#include // for Server + +#include // for future +#include // for unique_ptr +#include // for string + +#include "../../src/collective/tracker.h" // for Tracker +#include "xgboost/collective/result.h" // for Result +#include "xgboost/json.h" // for Json + +namespace xgboost::collective { +class FederatedTracker : public collective::Tracker { + std::unique_ptr server_; + std::string server_key_path_; + std::string server_cert_file_; + std::string client_cert_file_; + + public: + /** + * @brief CTOR + * + * @param config Configuration, other than the base configuration from Tracker, we have: + * + * - federated_secure: bool whether this is a secure server. + * - server_key_path: path to the key. + * - server_cert_path: certificate path. + * - client_cert_path: certificate path for client. + */ + explicit FederatedTracker(Json const& config); + ~FederatedTracker() override; + std::future Run() override; + // federated tracker do not provide initialization parameters, users have to provide it + // themseleves. + [[nodiscard]] Json WorkerArgs() const override { return Json{Null{}}; } + [[nodiscard]] Result Shutdown(); +}; +} // namespace xgboost::collective diff --git a/src/collective/tracker.h b/src/collective/tracker.h index 7bbee3c8d..f90373220 100644 --- a/src/collective/tracker.h +++ b/src/collective/tracker.h @@ -50,6 +50,7 @@ class Tracker { [[nodiscard]] virtual std::future Run() = 0; [[nodiscard]] virtual Json WorkerArgs() const = 0; [[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; } + [[nodiscard]] virtual std::int32_t Port() const { return port_; } }; class RabitTracker : public Tracker { @@ -124,7 +125,6 @@ class RabitTracker : public Tracker { std::future Run() override; - [[nodiscard]] std::int32_t Port() const { return port_; } [[nodiscard]] Json WorkerArgs() const override { Json args{Object{}}; args["DMLC_TRACKER_URI"] = String{host_}; diff --git a/tests/cpp/common/test_bitfield.cc b/tests/cpp/common/test_bitfield.cc index 902e69f85..564776642 100644 --- a/tests/cpp/common/test_bitfield.cc +++ b/tests/cpp/common/test_bitfield.cc @@ -97,4 +97,29 @@ TEST(BitField, Clear) { TestBitFieldClear(19); } } + +TEST(BitField, CTZ) { + { + auto cnt = TrailingZeroBits(0); + ASSERT_EQ(cnt, sizeof(std::uint32_t) * 8); + } + { + auto cnt = TrailingZeroBits(0b00011100); + ASSERT_EQ(cnt, 2); + cnt = detail::TrailingZeroBitsImpl(0b00011100); + ASSERT_EQ(cnt, 2); + } + { + auto cnt = TrailingZeroBits(0b00011101); + ASSERT_EQ(cnt, 0); + cnt = detail::TrailingZeroBitsImpl(0b00011101); + ASSERT_EQ(cnt, 0); + } + { + auto cnt = TrailingZeroBits(0b1000000000000000); + ASSERT_EQ(cnt, 15); + cnt = detail::TrailingZeroBitsImpl(0b1000000000000000); + ASSERT_EQ(cnt, 15); + } +} } // namespace xgboost diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 82a55450e..9adda8aed 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -572,4 +572,31 @@ class BaseMGPUTest : public ::testing::Test { class DeclareUnifiedDistributedTest(MetricTest) : public BaseMGPUTest{}; inline DeviceOrd FstCU() { return DeviceOrd::CUDA(0); } + +/** + * @brief poor man's gmock for message matching. + * + * @tparam Error The type of expected execption. + * + * @param submsg A substring of the actual error message. + * @param fn The function that throws Error + */ +template +void ExpectThrow(std::string submsg, Fn&& fn) { + try { + fn(); + } catch (Error const& exc) { + auto actual = std::string{exc.what()}; + ASSERT_NE(actual.find(submsg), std::string::npos) + << "Expecting substring `" << submsg << "` from the error message." + << " Got:\n" + << actual << "\n"; + return; + } catch (std::exception const& exc) { + auto actual = exc.what(); + ASSERT_TRUE(false) << "An unexpected type of exception is thrown. what:" << actual; + return; + } + ASSERT_TRUE(false) << "No exception is thrown"; +} } // namespace xgboost diff --git a/tests/cpp/plugin/federated/test_federated_comm.cc b/tests/cpp/plugin/federated/test_federated_comm.cc new file mode 100644 index 000000000..5bbde1bbb --- /dev/null +++ b/tests/cpp/plugin/federated/test_federated_comm.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2022-2023, XGBoost contributors + */ +#include + +#include // for string +#include // for thread + +#include "../../../../plugin/federated/federated_comm.h" +#include "../../collective/net_test.h" // for SocketTest +#include "../../helpers.h" // for ExpectThrow +#include "test_worker.h" // for TestFederated +#include "xgboost/json.h" // for Json + +namespace xgboost::collective { +namespace { +class FederatedCommTest : public SocketTest {}; +} // namespace + +TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) { + auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; }; + ExpectThrow("Invalid world size.", construct); +} + +TEST_F(FederatedCommTest, ThrowOnRankTooSmall) { + auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; }; + ExpectThrow("Invalid worker rank.", construct); +} + +TEST_F(FederatedCommTest, ThrowOnRankTooBig) { + auto construct = [] { FederatedComm comm{"localhost", 0, 1, 1}; }; + ExpectThrow("Invalid worker rank.", construct); +} + +TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) { + auto construct = [] { + Json config{Object{}}; + config["federated_server_address"] = std::string("localhost:0"); + config["federated_world_size"] = std::string("1"); + config["federated_rank"] = Integer(0); + FederatedComm comm(config); + }; + ExpectThrow("got: `String`", construct); +} + +TEST_F(FederatedCommTest, ThrowOnRankNotInteger) { + auto construct = [] { + Json config{Object{}}; + config["federated_server_address"] = std::string("localhost:0"); + config["federated_world_size"] = 1; + config["federated_rank"] = std::string("0"); + FederatedComm comm(config); + }; + ExpectThrow("got: `String`", construct); +} + +TEST_F(FederatedCommTest, GetWorldSizeAndRank) { + Json config{Object{}}; + config["federated_world_size"] = 6; + config["federated_rank"] = 3; + config["federated_server_address"] = String{"localhost:0"}; + FederatedComm comm{config}; + EXPECT_EQ(comm.World(), 6); + EXPECT_EQ(comm.Rank(), 3); +} + +TEST_F(FederatedCommTest, IsDistributed) { + FederatedComm comm{"localhost", 0, 2, 1}; + EXPECT_TRUE(comm.IsDistributed()); +} + +TEST_F(FederatedCommTest, InsecureTracker) { + std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u); + TestFederated(n_workers, [=](std::int32_t port, std::int32_t rank) { + Json config{Object{}}; + config["federated_world_size"] = n_workers; + config["federated_rank"] = rank; + config["federated_server_address"] = "0.0.0.0:" + std::to_string(port); + FederatedComm comm{config}; + ASSERT_EQ(comm.Rank(), rank); + ASSERT_EQ(comm.World(), n_workers); + }); +} +} // namespace xgboost::collective diff --git a/tests/cpp/plugin/federated/test_worker.h b/tests/cpp/plugin/federated/test_worker.h new file mode 100644 index 000000000..719b4c343 --- /dev/null +++ b/tests/cpp/plugin/federated/test_worker.h @@ -0,0 +1,42 @@ +/** + * Copyright 2022-2023, XGBoost contributors + */ +#pragma once + +#include + +#include // for ms +#include // for thread + +#include "../../../../plugin/federated/federated_tracker.h" +#include "xgboost/json.h" // for Json + +namespace xgboost::collective { +template +void TestFederated(std::int32_t n_workers, WorkerFn&& fn) { + Json config{Object()}; + config["federated_secure"] = Boolean{false}; + config["n_workers"] = Integer{n_workers}; + FederatedTracker tracker{config}; + auto fut = tracker.Run(); + + std::vector workers; + using namespace std::chrono_literals; + while (tracker.Port() == 0) { + std::this_thread::sleep_for(100ms); + } + std::int32_t port = tracker.Port(); + + for (std::int32_t i = 0; i < n_workers; ++i) { + workers.emplace_back([=] { fn(port, i); }); + } + + for (auto& t : workers) { + t.join(); + } + + auto rc = tracker.Shutdown(); + ASSERT_TRUE(rc.OK()) << rc.Report(); + ASSERT_TRUE(fut.get().OK()); +} +} // namespace xgboost::collective diff --git a/tests/cpp/plugin/helpers.h b/tests/cpp/plugin/helpers.h index b756adefd..3dd0c3a1f 100644 --- a/tests/cpp/plugin/helpers.h +++ b/tests/cpp/plugin/helpers.h @@ -1,5 +1,5 @@ -/*! - * Copyright 2022-2023 XGBoost contributors +/** + * Copyright 2022-2023, XGBoost contributors */ #pragma once @@ -26,7 +26,7 @@ class ServerForTest { explicit ServerForTest(std::size_t world_size) { server_thread_.reset(new std::thread([this, world_size] { grpc::ServerBuilder builder; - xgboost::federated::FederatedService service{world_size}; + xgboost::federated::FederatedService service{static_cast(world_size)}; int selected_port; builder.AddListeningPort("localhost:0", grpc::InsecureServerCredentials(), &selected_port); builder.RegisterService(&service);