[coll] Add nccl. (#9726)

Jiaming Yuan 2023-10-28 16:33:58 +08:00 committed by GitHub
parent 0c621094b3
commit 6755179e77
19 changed files with 924 additions and 111 deletions

View File

@@ -7,20 +7,23 @@
#include <cstddef>  // for size_t
#include <cstdint>  // for int8_t, int32_t, int64_t
#include <memory>   // for shared_ptr
-#include <numeric>  // for partial_sum
-#include <vector>   // for vector
+#include "broadcast.h"
#include "comm.h"                       // for Comm, Channel
#include "xgboost/collective/result.h"  // for Result
#include "xgboost/span.h"               // for Span
-namespace xgboost::collective::cpu_impl {
+namespace xgboost::collective {
+namespace cpu_impl {
Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size_t segment_size,
                     std::int32_t worker_off, std::shared_ptr<Channel> prev_ch,
                     std::shared_ptr<Channel> next_ch) {
  auto world = comm.World();
  auto rank = comm.Rank();
  CHECK_LT(worker_off, world);
+  if (world == 1) {
+    return Success();
+  }
  for (std::int32_t r = 0; r < world; ++r) {
    auto send_rank = (rank + world - r + worker_off) % world;
@@ -43,11 +46,29 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
  return Success();
}
+Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
+                           common::Span<std::int8_t> recv) {
+  std::size_t offset = 0;
+  for (std::int32_t r = 0; r < comm.World(); ++r) {
+    auto as_bytes = sizes[r];
+    auto rc = Broadcast(comm, recv.subspan(offset, as_bytes), r);
+    if (!rc.OK()) {
+      return rc;
+    }
+    offset += as_bytes;
+  }
+  return Success();
+}
+}  // namespace cpu_impl
+namespace detail {
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
-                                    common::Span<std::int8_t const> data,
-                                    common::Span<std::int64_t> offset,
+                                    common::Span<std::int64_t const> offset,
                                    common::Span<std::int8_t> erased_result) {
  auto world = comm.World();
+  if (world == 1) {
+    return Success();
+  }
  auto rank = comm.Rank();
  auto prev = BootstrapPrev(rank, comm.World());
@@ -56,17 +77,6 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
  auto prev_ch = comm.Chan(prev);
  auto next_ch = comm.Chan(next);
-  // get worker offset
-  CHECK_EQ(world + 1, offset.size());
-  std::fill_n(offset.data(), offset.size(), 0);
-  std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
-  CHECK_EQ(*offset.cbegin(), 0);
-  // copy data
-  auto current = erased_result.subspan(offset[rank], data.size_bytes());
-  auto erased_data = EraseType(data);
-  std::copy_n(erased_data.data(), erased_data.size(), current.data());
  for (std::int32_t r = 0; r < world; ++r) {
    auto send_rank = (rank + world - r) % world;
    auto send_off = offset[send_rank];
@@ -87,4 +97,5 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
  }
  return comm.Block();
}
-}  // namespace xgboost::collective::cpu_impl
+}  // namespace detail
+}  // namespace xgboost::collective
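For readers tracing the ring logic above, a minimal standalone sketch (not part of the patch) that prints which segment index each rank forwards at every step; the world size of 4 and worker_off of 0 are assumed example values.

#include <cstdio>

int main() {
  int const world = 4;       // assumed example worker count
  int const worker_off = 0;  // each rank starts from its own segment
  for (int r = 0; r < world; ++r) {
    for (int rank = 0; rank < world; ++rank) {
      // same indexing as the loop in RingAllgather above
      int send_rank = (rank + world - r + worker_off) % world;
      std::printf("step %d: rank %d sends segment %d\n", r, rank, send_rank);
    }
  }
  return 0;
}

At step r, a rank sends the segment that sits r hops behind it on the ring, which is exactly the segment it received from its predecessor in the previous step.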

View File

@@ -9,28 +9,47 @@
#include <type_traits>  // for remove_cv_t
#include <vector>       // for vector
#include "../common/type.h"             // for EraseType
#include "comm.h"                       // for Comm, Channel
#include "xgboost/collective/result.h"  // for Result
+#include "xgboost/linalg.h"
#include "xgboost/span.h"               // for Span
namespace xgboost::collective {
namespace cpu_impl {
/**
- * @param worker_off Segment offset. For example, if the rank 2 worker specifis worker_off
- * = 1, then it owns the third segment.
+ * @param worker_off Segment offset. For example, if the rank 2 worker specifies
+ * worker_off = 1, then it owns the third segment.
 */
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
                                   std::size_t segment_size, std::int32_t worker_off,
                                   std::shared_ptr<Channel> prev_ch,
                                   std::shared_ptr<Channel> next_ch);
-[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
-                                    common::Span<std::int8_t const> data,
-                                    common::Span<std::int64_t> offset,
-                                    common::Span<std::int8_t> erased_result);
+/**
+ * @brief Implement allgather-v using broadcast.
+ *
+ * https://arxiv.org/abs/1812.05964
+ */
+Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
+                           common::Span<std::int8_t> recv);
}  // namespace cpu_impl
+namespace detail {
+inline void AllgatherVOffset(common::Span<std::int64_t const> sizes,
+                             common::Span<std::int64_t> offset) {
+  // get worker offset
+  std::fill_n(offset.data(), offset.size(), 0);
+  std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
+  CHECK_EQ(*offset.cbegin(), 0);
+}
+// An implementation that's used by both cpu and gpu
+[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
+                                    common::Span<std::int64_t const> offset,
+                                    common::Span<std::int8_t> erased_result);
+}  // namespace detail
template <typename T>
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
  auto n_bytes = sizeof(T) * size;
@@ -68,9 +87,15 @@ template <typename T>
  auto h_result = common::Span{result.data(), result.size()};
  auto erased_result = common::EraseType(h_result);
  auto erased_data = common::EraseType(data);
-  std::vector<std::int64_t> offset(world + 1);
-  return cpu_impl::RingAllgatherV(comm, sizes, erased_data,
-                                  common::Span{offset.data(), offset.size()}, erased_result);
+  std::vector<std::int64_t> recv_segments(world + 1);
+  auto s_segments = common::Span{recv_segments.data(), recv_segments.size()};
+  // get worker offset
+  detail::AllgatherVOffset(sizes, s_segments);
+  // copy data
+  auto current = erased_result.subspan(recv_segments[rank], data.size_bytes());
+  std::copy_n(erased_data.data(), erased_data.size(), current.data());
+  return detail::RingAllgatherV(comm, sizes, s_segments, erased_result);
}
} // namespace xgboost::collective
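To make the segment layout concrete, here is a minimal self-contained sketch (not part of the patch) of what AllgatherVOffset computes: an exclusive prefix sum over the per-worker byte counts with one trailing entry. The sizes values {4, 8, 0, 12} are assumed for illustration only.

#include <cassert>
#include <cstdint>
#include <numeric>  // for partial_sum
#include <vector>

int main() {
  std::vector<std::int64_t> sizes{4, 8, 0, 12};           // bytes contributed by ranks 0..3 (example)
  std::vector<std::int64_t> offset(sizes.size() + 1, 0);  // world + 1 entries
  std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
  assert(offset[0] == 0);
  // offset == {0, 4, 12, 12, 24}: rank r's segment lives at [offset[r], offset[r + 1]).
  assert(offset.back() == 24);
  return 0;
}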

View File

@@ -8,16 +8,14 @@
#include <cstdint>     // for int8_t, int64_t
#include <functional>  // for bit_and, bit_or, bit_xor, plus
#include "allgather.h"  // for RingAllgatherV, RingAllgather
#include "allreduce.h"  // for Allreduce
#include "broadcast.h"  // for Broadcast
#include "comm.h"       // for Comm
-#include "xgboost/context.h"  // for Context
namespace xgboost::collective {
-[[nodiscard]] Result Coll::Allreduce(Context const*, Comm const& comm,
-                                     common::Span<std::int8_t> data, ArrayInterfaceHandler::Type,
-                                     Op op) {
+[[nodiscard]] Result Coll::Allreduce(Comm const& comm, common::Span<std::int8_t> data,
+                                     ArrayInterfaceHandler::Type, Op op) {
  namespace coll = ::xgboost::collective;
  auto redop_fn = [](auto lhs, auto out, auto elem_op) {
@@ -55,21 +53,45 @@ namespace xgboost::collective {
  return comm.Block();
}
-[[nodiscard]] Result Coll::Broadcast(Context const*, Comm const& comm,
-                                     common::Span<std::int8_t> data, std::int32_t root) {
+[[nodiscard]] Result Coll::Broadcast(Comm const& comm, common::Span<std::int8_t> data,
+                                     std::int32_t root) {
  return cpu_impl::Broadcast(comm, data, root);
}
-[[nodiscard]] Result Coll::Allgather(Context const*, Comm const& comm,
-                                     common::Span<std::int8_t> data, std::size_t size) {
+[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data,
+                                     std::int64_t size) {
  return RingAllgather(comm, data, size);
}
-[[nodiscard]] Result Coll::AllgatherV(Context const*, Comm const& comm,
-                                      common::Span<std::int8_t const> data,
+[[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
                                      common::Span<std::int64_t const> sizes,
                                      common::Span<std::int64_t> recv_segments,
-                                      common::Span<std::int8_t> recv) {
-  return cpu_impl::RingAllgatherV(comm, sizes, data, recv_segments, recv);
+                                      common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
+  // get worker offset
+  detail::AllgatherVOffset(sizes, recv_segments);
+  // copy data
+  auto current = recv.subspan(recv_segments[comm.Rank()], data.size_bytes());
+  if (current.data() != data.data()) {
+    std::copy_n(data.data(), data.size(), current.data());
+  }
+  switch (algo) {
+    case AllgatherVAlgo::kRing:
+      return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
+    case AllgatherVAlgo::kBcast:
+      return cpu_impl::BroadcastAllgatherV(comm, sizes, recv);
+    default: {
+      return Fail("Unknown algorithm for allgather-v");
+    }
+  }
}
+#if !defined(XGBOOST_USE_NCCL)
+Coll* Coll::MakeCUDAVar() {
+  LOG(FATAL) << "NCCL is required for device communication.";
+  return nullptr;
+}
+#endif
} // namespace xgboost::collective
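The CPU AllgatherV above now copies the caller's segment into recv before dispatching on the algorithm. A hedged usage sketch, modeled on the TestVAlgo test added later in this commit; ExampleAllgatherV is a hypothetical helper and an already-bootstrapped comm is assumed.

#include <cstdint>  // for int8_t, int32_t, int64_t
#include <numeric>  // for accumulate
#include <vector>   // for vector

#include "../common/type.h"  // for EraseType
#include "allgather.h"       // for RingAllgather
#include "coll.h"            // for Coll, AllgatherVAlgo

namespace xgboost::collective {
// Hypothetical helper, not part of the patch: gather a variable-length buffer from every rank.
Result ExampleAllgatherV(Comm const& comm) {
  // every rank contributes (rank + 1) int32 values
  std::vector<std::int32_t> data(comm.Rank() + 1, comm.Rank());
  auto s_data = common::Span{data.data(), data.size()};

  // exchange the per-rank byte counts first so that every rank knows the full `sizes`
  std::vector<std::int64_t> sizes(comm.World(), 0);
  sizes[comm.Rank()] = s_data.size_bytes();
  auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}, 1);
  if (!rc.OK()) {
    return rc;
  }

  std::vector<std::int64_t> recv_segments(comm.World() + 1, 0);
  std::vector<std::int8_t> recv(std::accumulate(sizes.cbegin(), sizes.cend(), std::int64_t{0}));

  Coll coll;
  return coll.AllgatherV(comm, common::EraseType(s_data),
                         common::Span{sizes.data(), sizes.size()},
                         common::Span{recv_segments.data(), recv_segments.size()},
                         common::Span{recv.data(), recv.size()}, AllgatherVAlgo::kBcast);
}
}  // namespace xgboost::collective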

src/collective/coll.cu (new file, 254 lines)
View File

@@ -0,0 +1,254 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#if defined(XGBOOST_USE_NCCL)
#include <cstdint> // for int8_t, int64_t
#include "../common/cuda_context.cuh"
#include "../common/device_helpers.cuh"
#include "../data/array_interface.h"
#include "allgather.h" // for AllgatherVOffset
#include "coll.cuh"
#include "comm.cuh"
#include "nccl.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
Coll* Coll::MakeCUDAVar() { return new NCCLColl{}; }
NCCLColl::~NCCLColl() = default;
namespace {
Result GetNCCLResult(ncclResult_t code) {
if (code == ncclSuccess) {
return Success();
}
std::stringstream ss;
ss << "NCCL failure: " << ncclGetErrorString(code) << ".";
if (code == ncclUnhandledCudaError) {
// nccl usually preserves the last error so we can get more details.
auto err = cudaPeekAtLastError();
ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n";
} else if (code == ncclSystemError) {
ss << " This might be caused by a network configuration issue. Please consider specifying "
"the network interface for NCCL via environment variables listed in its reference: "
"`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
}
return Fail(ss.str());
}
auto GetNCCLType(ArrayInterfaceHandler::Type type) {
auto fatal = [] {
LOG(FATAL) << "Invalid type for NCCL operation.";
return ncclHalf; // dummy return to silent the compiler warning.
};
using H = ArrayInterfaceHandler;
switch (type) {
case H::kF2:
return ncclHalf;
case H::kF4:
return ncclFloat32;
case H::kF8:
return ncclFloat64;
case H::kF16:
return fatal();
case H::kI1:
return ncclInt8;
case H::kI2:
return fatal();
case H::kI4:
return ncclInt32;
case H::kI8:
return ncclInt64;
case H::kU1:
return ncclUint8;
case H::kU2:
return fatal();
case H::kU4:
return ncclUint32;
case H::kU8:
return ncclUint64;
}
return fatal();
}
bool IsBitwiseOp(Op const& op) {
return op == Op::kBitwiseAND || op == Op::kBitwiseOR || op == Op::kBitwiseXOR;
}
template <typename Func>
void RunBitwiseAllreduce(dh::CUDAStreamView stream, common::Span<std::int8_t> out_buffer,
std::int8_t const* device_buffer, Func func, std::int32_t world_size,
std::size_t size) {
dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
auto result = device_buffer[idx];
for (auto rank = 1; rank < world_size; rank++) {
result = func(result, device_buffer[rank * size + idx]);
}
out_buffer[idx] = result;
});
}
[[nodiscard]] Result BitwiseAllReduce(NCCLComm const* pcomm, ncclComm_t handle,
common::Span<std::int8_t> data, Op op) {
dh::device_vector<std::int8_t> buffer(data.size() * pcomm->World());
auto* device_buffer = buffer.data().get();
// First gather data from all the workers.
CHECK(handle);
auto rc = GetNCCLResult(
ncclAllGather(data.data(), device_buffer, data.size(), ncclInt8, handle, pcomm->Stream()));
if (!rc.OK()) {
return rc;
}
// Then reduce locally.
switch (op) {
case Op::kBitwiseAND:
RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_and<std::int8_t>(),
pcomm->World(), data.size());
break;
case Op::kBitwiseOR:
RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_or<std::int8_t>(),
pcomm->World(), data.size());
break;
case Op::kBitwiseXOR:
RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_xor<std::int8_t>(),
pcomm->World(), data.size());
break;
default:
LOG(FATAL) << "Not a bitwise reduce operation.";
}
return Success();
}
ncclRedOp_t GetNCCLRedOp(Op const& op) {
ncclRedOp_t result{ncclMax};
switch (op) {
case Op::kMax:
result = ncclMax;
break;
case Op::kMin:
result = ncclMin;
break;
case Op::kSum:
result = ncclSum;
break;
default:
LOG(FATAL) << "Unsupported reduce operation.";
}
return result;
}
} // namespace
[[nodiscard]] Result NCCLColl::Allreduce(Comm const& comm, common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type type, Op op) {
if (!comm.IsDistributed()) {
return Success();
}
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
CHECK(nccl);
return Success() << [&] {
if (IsBitwiseOp(op)) {
return BitwiseAllReduce(nccl, nccl->Handle(), data, op);
} else {
return DispatchDType(type, [=](auto t) {
using T = decltype(t);
auto rdata = common::RestoreType<T>(data);
auto rc = ncclAllReduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
GetNCCLRedOp(op), nccl->Handle(), nccl->Stream());
return GetNCCLResult(rc);
});
}
} << [&] { return nccl->Block(); };
}
[[nodiscard]] Result NCCLColl::Broadcast(Comm const& comm, common::Span<std::int8_t> data,
std::int32_t root) {
if (!comm.IsDistributed()) {
return Success();
}
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
CHECK(nccl);
return Success() << [&] {
return GetNCCLResult(ncclBroadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root,
nccl->Handle(), nccl->Stream()));
} << [&] { return nccl->Block(); };
}
[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span<std::int8_t> data,
std::int64_t size) {
if (!comm.IsDistributed()) {
return Success();
}
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
CHECK(nccl);
auto send = data.subspan(comm.Rank() * size, size);
return Success() << [&] {
return GetNCCLResult(
ncclAllGather(send.data(), data.data(), size, ncclInt8, nccl->Handle(), nccl->Stream()));
} << [&] { return nccl->Block(); };
}
namespace cuda_impl {
/**
* @brief Implement allgather-v using broadcast.
*
* https://arxiv.org/abs/1812.05964
*/
Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes, common::Span<std::int8_t> recv) {
return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] {
std::size_t offset = 0;
for (std::int32_t r = 0; r < comm->World(); ++r) {
auto as_bytes = sizes[r];
auto rc = ncclBroadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
ncclInt8, r, comm->Handle(), dh::DefaultStream());
if (rc != ncclSuccess) {
return GetNCCLResult(rc);
}
offset += as_bytes;
}
return Success();
} << [] { return GetNCCLResult(ncclGroupEnd()); };
}
} // namespace cuda_impl
[[nodiscard]] Result NCCLColl::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,
common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
CHECK(nccl);
if (!comm.IsDistributed()) {
return Success();
}
switch (algo) {
case AllgatherVAlgo::kRing: {
return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] {
// get worker offset
detail::AllgatherVOffset(sizes, recv_segments);
// copy data
auto current = recv.subspan(recv_segments[comm.Rank()], data.size_bytes());
if (current.data() != data.data()) {
dh::safe_cuda(cudaMemcpyAsync(current.data(), data.data(), current.size_bytes(),
cudaMemcpyDeviceToDevice, nccl->Stream()));
}
return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
} << [] {
return GetNCCLResult(ncclGroupEnd());
} << [&] { return nccl->Block(); };
}
case AllgatherVAlgo::kBcast: {
return cuda_impl::BroadcastAllgatherV(nccl, data, sizes, recv);
}
default: {
return Fail("Unknown algorithm for allgather-v");
}
}
}
} // namespace xgboost::collective
#endif // defined(XGBOOST_USE_NCCL)
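The NCCL implementation above chains every step with `Success() << [&] { ... } << [&] { ... }`. The sketch below is a simplified stand-in, not the actual xgboost Result API; it only illustrates the idea that each lambda runs when all previous steps succeeded, so the first failure short-circuits the rest.

#include <iostream>
#include <string>
#include <utility>

// Simplified stand-in for xgboost::collective::Result; the real type also carries error details.
struct MiniResult {
  std::string msg;  // empty means success
  bool OK() const { return msg.empty(); }
};
MiniResult MiniSuccess() { return {}; }
MiniResult MiniFail(std::string m) { return {std::move(m)}; }

template <typename Fn>
MiniResult operator<<(MiniResult&& prev, Fn&& next) {
  if (!prev.OK()) {
    return std::move(prev);  // short-circuit: keep the first error
  }
  return next();
}

int main() {
  auto rc = MiniSuccess() << [] { return MiniSuccess(); }
                          << [] { return MiniFail("second step failed"); }
                          << [] { return MiniSuccess(); };  // never invoked
  std::cout << (rc.OK() ? "ok" : rc.msg) << "\n";
  return 0;
}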

src/collective/coll.cuh (new file, 29 lines)
View File

@@ -0,0 +1,29 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int8_t, int64_t
#include "../data/array_interface.h" // for ArrayInterfaceHandler
#include "coll.h" // for Coll
#include "comm.h" // for Comm
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
class NCCLColl : public Coll {
public:
~NCCLColl() override;
[[nodiscard]] Result Allreduce(Comm const& comm, common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type type, Op op) override;
[[nodiscard]] Result Broadcast(Comm const& comm, common::Span<std::int8_t> data,
std::int32_t root) override;
[[nodiscard]] Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
std::int64_t size) override;
[[nodiscard]] Result AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,
common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;
};
} // namespace xgboost::collective

View File

@@ -2,17 +2,20 @@
 * Copyright 2023, XGBoost Contributors
 */
#pragma once
-#include <cstddef>  // for size_t
#include <cstdint>  // for int8_t, int64_t
#include <memory>   // for enable_shared_from_this
#include "../data/array_interface.h"    // for ArrayInterfaceHandler
#include "comm.h"                       // for Comm
#include "xgboost/collective/result.h"  // for Result
-#include "xgboost/context.h"            // for Context
#include "xgboost/span.h"               // for Span
namespace xgboost::collective {
+enum class AllgatherVAlgo {
+  kRing = 0,   // use ring-based allgather-v
+  kBcast = 1,  // use broadcast-based allgather-v
+};
/**
 * @brief Interface and base implementation for collective.
 */
@@ -21,6 +24,8 @@ class Coll : public std::enable_shared_from_this<Coll> {
  Coll() = default;
  virtual ~Coll() noexcept(false) {}  // NOLINT
+  Coll* MakeCUDAVar();
  /**
   * @brief Allreduce
   *
@@ -29,8 +34,7 @@ class Coll : public std::enable_shared_from_this<Coll> {
   * @param [in] op Reduce operation. For custom operation, user needs to reach down to
   *             the CPU implementation.
   */
-  [[nodiscard]] virtual Result Allreduce(Context const* ctx, Comm const& comm,
-                                         common::Span<std::int8_t> data,
+  [[nodiscard]] virtual Result Allreduce(Comm const& comm, common::Span<std::int8_t> data,
                                         ArrayInterfaceHandler::Type type, Op op);
  /**
   * @brief Broadcast
@@ -38,29 +42,29 @@ class Coll : public std::enable_shared_from_this<Coll> {
   * @param [in,out] data Data buffer for input and output.
   * @param [in] root Root rank for broadcast.
   */
-  [[nodiscard]] virtual Result Broadcast(Context const* ctx, Comm const& comm,
-                                         common::Span<std::int8_t> data, std::int32_t root);
+  [[nodiscard]] virtual Result Broadcast(Comm const& comm, common::Span<std::int8_t> data,
+                                         std::int32_t root);
  /**
   * @brief Allgather
   *
   * @param [in,out] data Data buffer for input and output.
   * @param [in] size Size of data for each worker.
   */
-  [[nodiscard]] virtual Result Allgather(Context const* ctx, Comm const& comm,
-                                         common::Span<std::int8_t> data, std::size_t size);
+  [[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
+                                         std::int64_t size);
  /**
   * @brief Allgather with variable length.
   *
   * @param [in] data Input data for the current worker.
   * @param [in] sizes Size of the input from each worker.
-  * @param [out] recv_segments pre-allocated offset for each worker in the output, size
-  *                            should be equal to (world + 1).
+  * @param [out] recv_segments pre-allocated offset buffer for each worker in the output,
+  *                            size should be equal to (world + 1). GPU ring-based implementation
+  *                            doesn't use the buffer.
   * @param [out] recv pre-allocated buffer for output.
   */
-  [[nodiscard]] virtual Result AllgatherV(Context const* ctx, Comm const& comm,
-                                          common::Span<std::int8_t const> data,
+  [[nodiscard]] virtual Result AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
                                          common::Span<std::int64_t const> sizes,
                                          common::Span<std::int64_t> recv_segments,
-                                          common::Span<std::int8_t> recv);
+                                          common::Span<std::int8_t> recv, AllgatherVAlgo algo);
};
} // namespace xgboost::collective
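How the pieces fit together: the device variants are obtained from the CPU objects rather than constructed directly. A minimal sketch, modeled on NCCLWorkerForTest::Setup() added later in this commit; MakeDeviceColl and out_comm are hypothetical names, and ctx is assumed to be a Context bound to the local GPU.

#include <memory>  // for shared_ptr

#include "coll.h"  // for Coll
#include "comm.h"  // for Comm, Context

namespace xgboost::collective {
// Hypothetical helper, not part of the patch.
std::shared_ptr<Coll> MakeDeviceColl(Context const* ctx, Comm* comm,
                                     std::shared_ptr<Comm>* out_comm) {
  auto coll = std::shared_ptr<Coll>{new Coll{}};
  // NCCLComm, implemented in comm.cu below
  out_comm->reset(comm->MakeCUDAVar(ctx, coll));
  // NCCLColl, implemented in coll.cu above; without NCCL this path logs a fatal error
  return std::shared_ptr<Coll>{coll->MakeCUDAVar()};
}
}  // namespace xgboost::collective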

View File

@@ -262,7 +262,7 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
}
RabitComm::~RabitComm() noexcept(false) {
-  if (!IsDistributed()) {
+  if (!this->IsDistributed()) {
    return;
  }
  auto rc = this->Shutdown();

src/collective/comm.cu (new file, 112 lines)
View File

@@ -0,0 +1,112 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#if defined(XGBOOST_USE_NCCL)
#include <algorithm> // for sort
#include <cstddef> // for size_t
#include <cstdint> // for uint64_t, int8_t
#include <cstring> // for memcpy
#include <memory> // for shared_ptr
#include <sstream> // for stringstream
#include <vector> // for vector
#include "../common/device_helpers.cuh" // for DefaultStream
#include "../common/type.h" // for EraseType
#include "broadcast.h" // for Broadcast
#include "comm.cuh" // for NCCLComm
#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
namespace {
Result GetUniqueId(Comm const& comm, ncclUniqueId* pid) {
static const int kRootRank = 0;
ncclUniqueId id;
if (comm.Rank() == kRootRank) {
dh::safe_nccl(ncclGetUniqueId(&id));
}
auto rc = Broadcast(comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)},
kRootRank);
if (!rc.OK()) {
return rc;
}
*pid = id;
return Success();
}
inline constexpr std::size_t kUuidLength =
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(std::uint64_t);
void GetCudaUUID(xgboost::common::Span<std::uint64_t, kUuidLength> const& uuid, DeviceOrd device) {
cudaDeviceProp prob{};
dh::safe_cuda(cudaGetDeviceProperties(&prob, device.ordinal));
std::memcpy(uuid.data(), static_cast<void*>(&(prob.uuid)), sizeof(prob.uuid));
}
static std::string PrintUUID(xgboost::common::Span<std::uint64_t, kUuidLength> const& uuid) {
std::stringstream ss;
for (auto v : uuid) {
ss << std::hex << v;
}
return ss.str();
}
} // namespace
Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) {
return new NCCLComm{ctx, *this, pimpl};
}
NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl)
: Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(),
root.TaskID()},
stream_{dh::DefaultStream()} {
this->world_ = root.World();
this->rank_ = root.Rank();
this->domain_ = root.Domain();
if (!root.IsDistributed()) {
return;
}
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
std::vector<std::uint64_t> uuids(root.World() * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<std::uint64_t>{uuids.data(), uuids.size()};
auto s_this_uuid = s_uuid.subspan(root.Rank() * kUuidLength, kUuidLength);
GetCudaUUID(s_this_uuid, ctx->Device());
auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes());
CHECK(rc.OK()) << rc.Report();
std::vector<xgboost::common::Span<std::uint64_t, kUuidLength>> converted(root.World());
std::size_t j = 0;
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
converted[j] = s_uuid.subspan(i, kUuidLength);
j++;
}
std::sort(converted.begin(), converted.end());
auto iter = std::unique(converted.begin(), converted.end());
auto n_uniques = std::distance(converted.begin(), iter);
CHECK_EQ(n_uniques, root.World())
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
rc = GetUniqueId(root, &nccl_unique_id_);
CHECK(rc.OK()) << rc.Report();
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank()));
for (std::int32_t r = 0; r < root.World(); ++r) {
this->channels_.emplace_back(
std::make_shared<NCCLChannel>(root, r, nccl_comm_, dh::DefaultStream()));
}
}
NCCLComm::~NCCLComm() {
if (nccl_comm_) {
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
}
}
} // namespace xgboost::collective
#endif // defined(XGBOOST_USE_NCCL)

src/collective/comm.cuh (new file, 67 lines)
View File

@@ -0,0 +1,67 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#ifdef XGBOOST_USE_NCCL
#include "nccl.h"
#endif // XGBOOST_USE_NCCL
#include "../common/device_helpers.cuh"
#include "coll.h"
#include "comm.h"
#include "xgboost/context.h"
namespace xgboost::collective {
inline Result GetCUDAResult(cudaError rc) {
if (rc == cudaSuccess) {
return Success();
}
std::string msg = thrust::system_error(rc, thrust::cuda_category()).what();
return Fail(msg);
}
class NCCLComm : public Comm {
ncclComm_t nccl_comm_{nullptr};
ncclUniqueId nccl_unique_id_{};
dh::CUDAStreamView stream_;
public:
[[nodiscard]] ncclComm_t Handle() const { return nccl_comm_; }
explicit NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl);
[[nodiscard]] Result LogTracker(std::string) const override {
LOG(FATAL) << "Device comm is used for logging.";
return Fail("Undefined.");
}
~NCCLComm() override;
[[nodiscard]] bool IsFederated() const override { return false; }
[[nodiscard]] dh::CUDAStreamView Stream() const { return stream_; }
[[nodiscard]] Result Block() const override {
auto rc = this->Stream().Sync(false);
return GetCUDAResult(rc);
}
};
class NCCLChannel : public Channel {
std::int32_t rank_{-1};
ncclComm_t nccl_comm_{};
dh::CUDAStreamView stream_;
public:
explicit NCCLChannel(Comm const& comm, std::int32_t rank, ncclComm_t nccl_comm,
dh::CUDAStreamView stream)
: rank_{rank}, nccl_comm_{nccl_comm}, Channel{comm, nullptr}, stream_{stream} {}
void SendAll(std::int8_t const* ptr, std::size_t n) override {
dh::safe_nccl(ncclSend(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
}
void RecvAll(std::int8_t* ptr, std::size_t n) override {
dh::safe_nccl(ncclRecv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
}
[[nodiscard]] Result Block() override {
auto rc = stream_.Sync(false);
return GetCUDAResult(rc);
}
};
} // namespace xgboost::collective

View File

@@ -2,20 +2,20 @@
 * Copyright 2023, XGBoost Contributors
 */
#pragma once
#include <chrono>   // for seconds
#include <cstddef>  // for size_t
#include <cstdint>  // for int32_t
#include <memory>   // for shared_ptr
#include <string>   // for string
#include <thread>   // for thread
-#include <type_traits>  // for remove_const_t
#include <utility>  // for move
#include <vector>   // for vector
#include "loop.h"                       // for Loop
#include "protocol.h"                   // for PeerInfo
#include "xgboost/collective/result.h"  // for Result
#include "xgboost/collective/socket.h"  // for TCPSocket
+#include "xgboost/context.h"            // for Context
#include "xgboost/span.h"               // for Span
namespace xgboost::collective {
@@ -35,13 +35,14 @@ inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) {
}
class Channel;
+class Coll;
/**
 * @brief Base communicator storing info about the tracker and other communicators.
 */
class Comm {
 protected:
-  std::int32_t world_{1};
+  std::int32_t world_{-1};
  std::int32_t rank_{0};
  std::chrono::seconds timeout_{DefaultTimeoutSec()};
  std::int32_t retry_{DefaultRetry()};
@@ -69,12 +70,14 @@ class Comm {
  [[nodiscard]] Result ConnectTracker(TCPSocket* out) const;
  [[nodiscard]] auto Domain() const { return domain_; }
  [[nodiscard]] auto Timeout() const { return timeout_; }
+  [[nodiscard]] auto Retry() const { return retry_; }
+  [[nodiscard]] auto TaskID() const { return task_id_; }
  [[nodiscard]] auto Rank() const { return rank_; }
-  [[nodiscard]] auto World() const { return world_; }
-  [[nodiscard]] bool IsDistributed() const { return World() > 1; }
+  [[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; }
+  [[nodiscard]] bool IsDistributed() const { return world_ != -1; }
  void Submit(Loop::Op op) const { loop_->Submit(op); }
-  [[nodiscard]] Result Block() const { return loop_->Block(); }
+  [[nodiscard]] virtual Result Block() const { return loop_->Block(); }
  [[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const {
    return channels_.at(rank);
@@ -83,6 +86,8 @@ class Comm {
  [[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
  [[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
+  Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl);
};
class RabitComm : public Comm {
@@ -116,7 +121,7 @@ class Channel {
  explicit Channel(Comm const& comm, std::shared_ptr<TCPSocket> sock)
      : sock_{std::move(sock)}, comm_{comm} {}
-  void SendAll(std::int8_t const* ptr, std::size_t n) {
+  virtual void SendAll(std::int8_t const* ptr, std::size_t n) {
    Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast<std::int8_t*>(ptr), n, sock_.get(), 0};
    CHECK(sock_.get());
    comm_.Submit(std::move(op));
@@ -125,7 +130,7 @@ class Channel {
    this->SendAll(data.data(), data.size_bytes());
  }
-  void RecvAll(std::int8_t* ptr, std::size_t n) {
+  virtual void RecvAll(std::int8_t* ptr, std::size_t n) {
    Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0};
    CHECK(sock_.get());
    comm_.Submit(std::move(op));
@@ -133,7 +138,7 @@ class Channel {
  void RecvAll(common::Span<std::int8_t> data) { this->RecvAll(data.data(), data.size_bytes()); }
  [[nodiscard]] auto Socket() const { return sock_; }
-  [[nodiscard]] Result Block() { return comm_.Block(); }
+  [[nodiscard]] virtual Result Block() { return comm_.Block(); }
};
enum class Op { kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 };

View File

@@ -1169,7 +1169,13 @@ class CUDAStreamView {
  operator cudaStream_t() const {  // NOLINT
    return stream_;
  }
-  void Sync() { dh::safe_cuda(cudaStreamSynchronize(stream_)); }
+  cudaError_t Sync(bool error = true) {
+    if (error) {
+      dh::safe_cuda(cudaStreamSynchronize(stream_));
+      return cudaSuccess;
+    }
+    return cudaStreamSynchronize(stream_);
+  }
};
inline void CUDAEvent::Record(CUDAStreamView stream) {  // NOLINT
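The Sync change above exists so that stream synchronization can be surfaced as an error code instead of failing fast. A brief illustrative sketch of the two calling modes; ExampleSync is a hypothetical name and is not part of the patch.

#include "../common/device_helpers.cuh"  // for CUDAStreamView, safe_cuda

// Hypothetical example, not part of the patch.
void ExampleSync(dh::CUDAStreamView stream) {
  // default mode: errors are routed through dh::safe_cuda (fail fast), cudaSuccess is returned
  stream.Sync();
  // non-throwing mode: the raw cudaError_t is handed back so the caller can convert it,
  // e.g. into a collective::Result as NCCLComm::Block does in comm.cuh above
  cudaError_t err = stream.Sync(false);
  if (err != cudaSuccess) {
    LOG(FATAL) << cudaGetErrorString(err);
  }
}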

View File

@@ -19,7 +19,6 @@
#include "../common/cuda_context.cuh"  // CUDAContext
#include "../common/device_helpers.cuh"
#include "../common/hist_util.h"
-#include "../common/io.h"
#include "../common/timer.h"
#include "../data/ellpack_page.cuh"
#include "../data/ellpack_page.h"
@@ -39,7 +38,6 @@
#include "xgboost/data.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/json.h"
-#include "xgboost/parameter.h"
#include "xgboost/span.h"
#include "xgboost/task.h"  // for ObjInfo
#include "xgboost/tree_model.h"

View File

@@ -14,6 +14,7 @@
#include <vector>  // for vector
#include "../../../src/collective/allgather.h"  // for RingAllgather
+#include "../../../src/collective/coll.h"       // for Coll
#include "../../../src/collective/comm.h"       // for RabitComm
#include "gtest/gtest.h"                        // for AssertionR...
#include "test_worker.h"                        // for TestDistri...
@@ -63,37 +64,79 @@ class Worker : public WorkerForTest {
    }
  }
-  void TestV() {
-    {
-      // basic test
-      std::int32_t n{comm_.Rank()};
-      std::vector<std::int32_t> result;
-      auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result);
-      ASSERT_TRUE(rc.OK()) << rc.Report();
-      for (std::int32_t i = 0; i < comm_.World(); ++i) {
-        ASSERT_EQ(result[i], i);
-      }
-    }
-    {
-      // V test
-      std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
-      std::vector<std::int32_t> result;
-      auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result);
-      ASSERT_TRUE(rc.OK()) << rc.Report();
-      ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2);
-      std::int32_t k{0};
-      for (std::int32_t r = 0; r < comm_.World(); ++r) {
-        auto seg = common::Span{result.data(), result.size()}.subspan(k, (r + 1));
-        if (comm_.Rank() == 0) {
-          for (auto v : seg) {
-            ASSERT_EQ(v, r);
-          }
-          k += seg.size();
-        }
-      }
-    }
-  }
+  void CheckV(common::Span<std::int32_t> result) {
+    std::int32_t k{0};
+    for (std::int32_t r = 0; r < comm_.World(); ++r) {
+      auto seg = common::Span{result.data(), result.size()}.subspan(k, (r + 1));
+      if (comm_.Rank() == 0) {
+        for (auto v : seg) {
+          ASSERT_EQ(v, r);
+        }
+      }
+      k += seg.size();
+    }
+  }
+
+  void TestVRing() {
+    // V test
+    std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
+    std::vector<std::int32_t> result;
+    auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result);
+    ASSERT_TRUE(rc.OK()) << rc.Report();
+    ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2);
+    CheckV(result);
+  }
+
+  void TestVBasic() {
+    // basic test
+    std::int32_t n{comm_.Rank()};
+    std::vector<std::int32_t> result;
+    auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result);
+    ASSERT_TRUE(rc.OK()) << rc.Report();
+    for (std::int32_t i = 0; i < comm_.World(); ++i) {
+      ASSERT_EQ(result[i], i);
+    }
+  }
+
+  void TestVAlgo() {
+    // V test, broadcast
+    std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
+    auto s_data = common::Span{data.data(), data.size()};
+    std::vector<std::int64_t> sizes(comm_.World(), 0);
+    sizes[comm_.Rank()] = s_data.size_bytes();
+    auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
+    ASSERT_TRUE(rc.OK()) << rc.Report();
+
+    std::shared_ptr<Coll> pcoll{new Coll{}};
+    std::vector<std::int64_t> recv_segments(comm_.World() + 1, 0);
+    std::vector<std::int32_t> recv(std::accumulate(sizes.cbegin(), sizes.cend(), 0));
+    auto s_recv = common::Span{recv.data(), recv.size()};
+    rc = pcoll->AllgatherV(comm_, common::EraseType(s_data),
+                           common::Span{sizes.data(), sizes.size()},
+                           common::Span{recv_segments.data(), recv_segments.size()},
+                           common::EraseType(s_recv), AllgatherVAlgo::kBcast);
+    ASSERT_TRUE(rc.OK());
+    CheckV(s_recv);
+
+    // Test inplace
+    auto test_inplace = [&] (AllgatherVAlgo algo) {
+      std::fill_n(s_recv.data(), s_recv.size(), 0);
+      auto current = s_recv.subspan(recv_segments[comm_.Rank()],
+                                    recv_segments[comm_.Rank() + 1] - recv_segments[comm_.Rank()]);
+      std::copy_n(data.data(), data.size(), current.data());
+      rc = pcoll->AllgatherV(comm_, common::EraseType(current),
+                             common::Span{sizes.data(), sizes.size()},
+                             common::Span{recv_segments.data(), recv_segments.size()},
+                             common::EraseType(s_recv), algo);
+      ASSERT_TRUE(rc.OK());
+      CheckV(s_recv);
+    };
+    test_inplace(AllgatherVAlgo::kBcast);
+    test_inplace(AllgatherVAlgo::kRing);
+  }
};
}  // namespace
@@ -106,12 +149,30 @@ TEST_F(AllgatherTest, Basic) {
  });
}
-TEST_F(AllgatherTest, V) {
+TEST_F(AllgatherTest, VBasic) {
  std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t r) {
    Worker worker{host, port, timeout, n_workers, r};
-    worker.TestV();
+    worker.TestVBasic();
  });
}
+
+TEST_F(AllgatherTest, VRing) {
+  std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
+  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
+                                 std::int32_t r) {
+    Worker worker{host, port, timeout, n_workers, r};
+    worker.TestVRing();
+  });
+}
+
+TEST_F(AllgatherTest, VAlgo) {
+  std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
+  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
+                                 std::int32_t r) {
+    Worker worker{host, port, timeout, n_workers, r};
+    worker.TestVAlgo();
+  });
+}
} // namespace xgboost::collective

View File

@@ -0,0 +1,117 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#if defined(XGBOOST_USE_NCCL)
#include <gtest/gtest.h>
#include <thrust/device_vector.h> // for device_vector
#include <thrust/equal.h> // for equal
#include <xgboost/span.h> // for Span
#include <cstddef> // for size_t
#include <cstdint> // for int32_t, int64_t
#include <vector> // for vector
#include "../../../src/collective/allgather.h" // for RingAllgather
#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector
#include "../../../src/common/type.h" // for EraseType
#include "test_worker.cuh" // for NCCLWorkerForTest
#include "test_worker.h" // for TestDistributed, WorkerForTest
namespace xgboost::collective {
namespace {
class Worker : public NCCLWorkerForTest {
public:
using NCCLWorkerForTest::NCCLWorkerForTest;
void TestV(AllgatherVAlgo algo) {
{
// basic test
std::size_t n = 1;
// create data
dh::device_vector<std::int32_t> data(n, comm_.Rank());
auto s_data = common::EraseType(common::Span{data.data().get(), data.size()});
// get size
std::vector<std::int64_t> sizes(comm_.World(), -1);
sizes[comm_.Rank()] = s_data.size_bytes();
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
ASSERT_TRUE(rc.OK()) << rc.Report();
// create result
dh::device_vector<std::int32_t> result(comm_.World(), -1);
auto s_result = common::EraseType(dh::ToSpan(result));
std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
ASSERT_TRUE(rc.OK()) << rc.Report();
for (std::int32_t i = 0; i < comm_.World(); ++i) {
ASSERT_EQ(result[i], i);
}
}
{
// V test
std::size_t n = 256 * 256;
// create data
dh::device_vector<std::int32_t> data(n * nccl_comm_->Rank(), nccl_comm_->Rank());
auto s_data = common::EraseType(common::Span{data.data().get(), data.size()});
// get size
std::vector<std::int64_t> sizes(nccl_comm_->World(), 0);
sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes();
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
ASSERT_TRUE(rc.OK()) << rc.Report();
auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
// create result
dh::device_vector<std::int32_t> result(n_bytes / sizeof(std::int32_t), -1);
auto s_result = common::EraseType(dh::ToSpan(result));
std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
ASSERT_TRUE(rc.OK()) << rc.Report();
// check segment size
if (algo != AllgatherVAlgo::kBcast) {
auto size = recv_seg[nccl_comm_->Rank() + 1] - recv_seg[nccl_comm_->Rank()];
ASSERT_EQ(size, n * nccl_comm_->Rank() * sizeof(std::int32_t));
ASSERT_EQ(size, sizes[nccl_comm_->Rank()]);
}
// check data
std::size_t k{0};
for (std::int32_t r = 0; r < nccl_comm_->World(); ++r) {
std::size_t s = n * r;
auto current = dh::ToSpan(result).subspan(k, s);
std::vector<std::int32_t> h_data(current.size());
dh::CopyDeviceSpanToVector(&h_data, current);
for (auto v : h_data) {
ASSERT_EQ(v, r);
}
k += s;
}
}
}
};
class AllgatherTestGPU : public SocketTest {};
} // namespace
TEST_F(AllgatherTestGPU, MGPUTestVRing) {
auto n_workers = common::AllVisibleGPUs();
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
Worker w{host, port, timeout, n_workers, r};
w.Setup();
w.TestV(AllgatherVAlgo::kRing);
w.TestV(AllgatherVAlgo::kBcast);
});
}
TEST_F(AllgatherTestGPU, MGPUTestVBcast) {
auto n_workers = common::AllVisibleGPUs();
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
Worker w{host, port, timeout, n_workers, r};
w.Setup();
w.TestV(AllgatherVAlgo::kBcast);
});
}
} // namespace xgboost::collective
#endif // defined(XGBOOST_USE_NCCL)

View File

@@ -6,10 +6,10 @@
#include "../../../src/collective/allreduce.h"
#include "../../../src/collective/coll.h"  // for Coll
#include "../../../src/collective/tracker.h"
+#include "../../../src/common/type.h"  // for EraseType
#include "test_worker.h"  // for WorkerForTest, TestDistributed
namespace xgboost::collective {
namespace {
class AllreduceWorker : public WorkerForTest {
 public:
@@ -50,11 +50,10 @@ class AllreduceWorker : public WorkerForTest {
  }
  void BitOr() {
-    Context ctx;
    std::vector<std::uint32_t> data(comm_.World(), 0);
    data[comm_.Rank()] = ~std::uint32_t{0};
    auto pcoll = std::shared_ptr<Coll>{new Coll{}};
-    auto rc = pcoll->Allreduce(&ctx, comm_, EraseType(common::Span{data.data(), data.size()}),
+    auto rc = pcoll->Allreduce(comm_, common::EraseType(common::Span{data.data(), data.size()}),
                               ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
    ASSERT_TRUE(rc.OK()) << rc.Report();
    for (auto v : data) {

View File

@@ -0,0 +1,70 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#if defined(XGBOOST_USE_NCCL)
#include <gtest/gtest.h>
#include <thrust/host_vector.h> // for host_vector
#include "../../../src/collective/coll.h" // for Coll
#include "../../../src/common/common.h"
#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector
#include "../../../src/common/type.h" // for EraseType
#include "../helpers.h" // for MakeCUDACtx
#include "test_worker.cuh" // for NCCLWorkerForTest
#include "test_worker.h" // for WorkerForTest, TestDistributed
namespace xgboost::collective {
namespace {
class AllreduceTestGPU : public SocketTest {};
class Worker : public NCCLWorkerForTest {
public:
using NCCLWorkerForTest::NCCLWorkerForTest;
void BitOr() {
dh::device_vector<std::uint32_t> data(comm_.World(), 0);
data[comm_.Rank()] = ~std::uint32_t{0};
auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
ASSERT_TRUE(rc.OK()) << rc.Report();
thrust::host_vector<std::uint32_t> h_data(data.size());
thrust::copy(data.cbegin(), data.cend(), h_data.begin());
for (auto v : h_data) {
ASSERT_EQ(v, ~std::uint32_t{0});
}
}
void Acc() {
dh::device_vector<double> data(314, 1.5);
auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
ArrayInterfaceHandler::kF8, Op::kSum);
ASSERT_TRUE(rc.OK()) << rc.Report();
for (std::size_t i = 0; i < data.size(); ++i) {
auto v = data[i];
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
}
}
};
} // namespace
TEST_F(AllreduceTestGPU, BitOr) {
auto n_workers = common::AllVisibleGPUs();
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
Worker w{host, port, timeout, n_workers, r};
w.Setup();
w.BitOr();
});
}
TEST_F(AllreduceTestGPU, Sum) {
auto n_workers = common::AllVisibleGPUs();
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
Worker w{host, port, timeout, n_workers, r};
w.Setup();
w.Acc();
});
}
} // namespace xgboost::collective
#endif // defined(XGBOOST_USE_NCCL)

View File

@@ -47,5 +47,5 @@ TEST_F(BroadcastTest, Basic) {
    Worker worker{host, port, timeout, n_workers, r};
    worker.Run();
  });
-}
+}  // namespace
} // namespace xgboost::collective

View File

@@ -0,0 +1,32 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <memory> // for shared_ptr
#include "../../../src/collective/coll.h" // for Coll
#include "../../../src/collective/comm.h" // for Comm
#include "test_worker.h"
#include "xgboost/context.h" // for Context
namespace xgboost::collective {
class NCCLWorkerForTest : public WorkerForTest {
protected:
std::shared_ptr<Coll> coll_;
std::shared_ptr<xgboost::collective::Comm> nccl_comm_;
std::shared_ptr<Coll> nccl_coll_;
Context ctx_;
public:
using WorkerForTest::WorkerForTest;
void Setup() {
ctx_ = MakeCUDACtx(comm_.Rank());
coll_.reset(new Coll{});
nccl_comm_.reset(this->comm_.MakeCUDAVar(&ctx_, coll_));
nccl_coll_.reset(coll_->MakeCUDAVar());
ASSERT_EQ(comm_.World(), nccl_comm_->World());
ASSERT_EQ(comm_.Rank(), nccl_comm_->Rank());
}
};
} // namespace xgboost::collective

View File

@@ -1,6 +1,7 @@
/**
 * Copyright 2023, XGBoost Contributors
 */
+#pragma once
#include <gtest/gtest.h>
#include <chrono>  // for seconds