[coll] Add nccl. (#9726)
This commit is contained in:
parent
0c621094b3
commit
6755179e77
@ -7,20 +7,23 @@
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t, int32_t, int64_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <numeric> // for partial_sum
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "broadcast.h"
|
||||
#include "comm.h" // for Comm, Channel
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective::cpu_impl {
|
||||
namespace xgboost::collective {
|
||||
namespace cpu_impl {
|
||||
Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size_t segment_size,
|
||||
std::int32_t worker_off, std::shared_ptr<Channel> prev_ch,
|
||||
std::shared_ptr<Channel> next_ch) {
|
||||
auto world = comm.World();
|
||||
auto rank = comm.Rank();
|
||||
CHECK_LT(worker_off, world);
|
||||
if (world == 1) {
|
||||
return Success();
|
||||
}
|
||||
|
||||
for (std::int32_t r = 0; r < world; ++r) {
|
||||
auto send_rank = (rank + world - r + worker_off) % world;
|
||||
@ -43,11 +46,29 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
return Success();
|
||||
}
|
||||
|
||||
Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int8_t> recv) {
|
||||
std::size_t offset = 0;
|
||||
for (std::int32_t r = 0; r < comm.World(); ++r) {
|
||||
auto as_bytes = sizes[r];
|
||||
auto rc = Broadcast(comm, recv.subspan(offset, as_bytes), r);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
offset += as_bytes;
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
} // namespace cpu_impl
|
||||
|
||||
namespace detail {
|
||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t> offset,
|
||||
common::Span<std::int64_t const> offset,
|
||||
common::Span<std::int8_t> erased_result) {
|
||||
auto world = comm.World();
|
||||
if (world == 1) {
|
||||
return Success();
|
||||
}
|
||||
auto rank = comm.Rank();
|
||||
|
||||
auto prev = BootstrapPrev(rank, comm.World());
|
||||
@ -56,17 +77,6 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
auto prev_ch = comm.Chan(prev);
|
||||
auto next_ch = comm.Chan(next);
|
||||
|
||||
// get worker offset
|
||||
CHECK_EQ(world + 1, offset.size());
|
||||
std::fill_n(offset.data(), offset.size(), 0);
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
|
||||
CHECK_EQ(*offset.cbegin(), 0);
|
||||
|
||||
// copy data
|
||||
auto current = erased_result.subspan(offset[rank], data.size_bytes());
|
||||
auto erased_data = EraseType(data);
|
||||
std::copy_n(erased_data.data(), erased_data.size(), current.data());
|
||||
|
||||
for (std::int32_t r = 0; r < world; ++r) {
|
||||
auto send_rank = (rank + world - r) % world;
|
||||
auto send_off = offset[send_rank];
|
||||
@ -87,4 +97,5 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
}
|
||||
return comm.Block();
|
||||
}
|
||||
} // namespace xgboost::collective::cpu_impl
|
||||
} // namespace detail
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -12,25 +12,44 @@
|
||||
#include "../common/type.h" // for EraseType
|
||||
#include "comm.h" // for Comm, Channel
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace cpu_impl {
|
||||
/**
|
||||
* @param worker_off Segment offset. For example, if the rank 2 worker specifis worker_off
|
||||
* = 1, then it owns the third segment.
|
||||
* @param worker_off Segment offset. For example, if the rank 2 worker specifies
|
||||
* worker_off = 1, then it owns the third segment.
|
||||
*/
|
||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::size_t segment_size, std::int32_t worker_off,
|
||||
std::shared_ptr<Channel> prev_ch,
|
||||
std::shared_ptr<Channel> next_ch);
|
||||
|
||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t> offset,
|
||||
common::Span<std::int8_t> erased_result);
|
||||
/**
|
||||
* @brief Implement allgather-v using broadcast.
|
||||
*
|
||||
* https://arxiv.org/abs/1812.05964
|
||||
*/
|
||||
Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int8_t> recv);
|
||||
} // namespace cpu_impl
|
||||
|
||||
namespace detail {
|
||||
inline void AllgatherVOffset(common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> offset) {
|
||||
// get worker offset
|
||||
std::fill_n(offset.data(), offset.size(), 0);
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
|
||||
CHECK_EQ(*offset.cbegin(), 0);
|
||||
}
|
||||
|
||||
// An implementation that's used by both cpu and gpu
|
||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t const> offset,
|
||||
common::Span<std::int8_t> erased_result);
|
||||
} // namespace detail
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
|
||||
auto n_bytes = sizeof(T) * size;
|
||||
@ -68,9 +87,15 @@ template <typename T>
|
||||
auto h_result = common::Span{result.data(), result.size()};
|
||||
auto erased_result = common::EraseType(h_result);
|
||||
auto erased_data = common::EraseType(data);
|
||||
std::vector<std::int64_t> offset(world + 1);
|
||||
std::vector<std::int64_t> recv_segments(world + 1);
|
||||
auto s_segments = common::Span{recv_segments.data(), recv_segments.size()};
|
||||
|
||||
return cpu_impl::RingAllgatherV(comm, sizes, erased_data,
|
||||
common::Span{offset.data(), offset.size()}, erased_result);
|
||||
// get worker offset
|
||||
detail::AllgatherVOffset(sizes, s_segments);
|
||||
// copy data
|
||||
auto current = erased_result.subspan(recv_segments[rank], data.size_bytes());
|
||||
std::copy_n(erased_data.data(), erased_data.size(), current.data());
|
||||
|
||||
return detail::RingAllgatherV(comm, sizes, s_segments, erased_result);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -12,12 +12,10 @@
|
||||
#include "allreduce.h" // for Allreduce
|
||||
#include "broadcast.h" // for Broadcast
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/context.h" // for Context
|
||||
|
||||
namespace xgboost::collective {
|
||||
[[nodiscard]] Result Coll::Allreduce(Context const*, Comm const& comm,
|
||||
common::Span<std::int8_t> data, ArrayInterfaceHandler::Type,
|
||||
Op op) {
|
||||
[[nodiscard]] Result Coll::Allreduce(Comm const& comm, common::Span<std::int8_t> data,
|
||||
ArrayInterfaceHandler::Type, Op op) {
|
||||
namespace coll = ::xgboost::collective;
|
||||
|
||||
auto redop_fn = [](auto lhs, auto out, auto elem_op) {
|
||||
@ -55,21 +53,45 @@ namespace xgboost::collective {
|
||||
return comm.Block();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Coll::Broadcast(Context const*, Comm const& comm,
|
||||
common::Span<std::int8_t> data, std::int32_t root) {
|
||||
[[nodiscard]] Result Coll::Broadcast(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) {
|
||||
return cpu_impl::Broadcast(comm, data, root);
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Coll::Allgather(Context const*, Comm const& comm,
|
||||
common::Span<std::int8_t> data, std::size_t size) {
|
||||
[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) {
|
||||
return RingAllgather(comm, data, size);
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Coll::AllgatherV(Context const*, Comm const& comm,
|
||||
common::Span<std::int8_t const> data,
|
||||
[[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
common::Span<std::int8_t> recv) {
|
||||
return cpu_impl::RingAllgatherV(comm, sizes, data, recv_segments, recv);
|
||||
common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
|
||||
// get worker offset
|
||||
detail::AllgatherVOffset(sizes, recv_segments);
|
||||
|
||||
// copy data
|
||||
auto current = recv.subspan(recv_segments[comm.Rank()], data.size_bytes());
|
||||
if (current.data() != data.data()) {
|
||||
std::copy_n(data.data(), data.size(), current.data());
|
||||
}
|
||||
|
||||
switch (algo) {
|
||||
case AllgatherVAlgo::kRing:
|
||||
return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
|
||||
case AllgatherVAlgo::kBcast:
|
||||
return cpu_impl::BroadcastAllgatherV(comm, sizes, recv);
|
||||
default: {
|
||||
return Fail("Unknown algorithm for allgather-v");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_NCCL)
|
||||
Coll* Coll::MakeCUDAVar() {
|
||||
LOG(FATAL) << "NCCL is required for device communication.";
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace xgboost::collective
|
||||
|
||||
254
src/collective/coll.cu
Normal file
254
src/collective/coll.cu
Normal file
@ -0,0 +1,254 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include <cstdint> // for int8_t, int64_t
|
||||
|
||||
#include "../common/cuda_context.cuh"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../data/array_interface.h"
|
||||
#include "allgather.h" // for AllgatherVOffset
|
||||
#include "coll.cuh"
|
||||
#include "comm.cuh"
|
||||
#include "nccl.h"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
Coll* Coll::MakeCUDAVar() { return new NCCLColl{}; }
|
||||
|
||||
NCCLColl::~NCCLColl() = default;
|
||||
namespace {
|
||||
Result GetNCCLResult(ncclResult_t code) {
|
||||
if (code == ncclSuccess) {
|
||||
return Success();
|
||||
}
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "NCCL failure: " << ncclGetErrorString(code) << ".";
|
||||
if (code == ncclUnhandledCudaError) {
|
||||
// nccl usually preserves the last error so we can get more details.
|
||||
auto err = cudaPeekAtLastError();
|
||||
ss << " CUDA error: " << thrust::system_error(err, thrust::cuda_category()).what() << "\n";
|
||||
} else if (code == ncclSystemError) {
|
||||
ss << " This might be caused by a network configuration issue. Please consider specifying "
|
||||
"the network interface for NCCL via environment variables listed in its reference: "
|
||||
"`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
|
||||
}
|
||||
return Fail(ss.str());
|
||||
}
|
||||
|
||||
auto GetNCCLType(ArrayInterfaceHandler::Type type) {
|
||||
auto fatal = [] {
|
||||
LOG(FATAL) << "Invalid type for NCCL operation.";
|
||||
return ncclHalf; // dummy return to silent the compiler warning.
|
||||
};
|
||||
using H = ArrayInterfaceHandler;
|
||||
switch (type) {
|
||||
case H::kF2:
|
||||
return ncclHalf;
|
||||
case H::kF4:
|
||||
return ncclFloat32;
|
||||
case H::kF8:
|
||||
return ncclFloat64;
|
||||
case H::kF16:
|
||||
return fatal();
|
||||
case H::kI1:
|
||||
return ncclInt8;
|
||||
case H::kI2:
|
||||
return fatal();
|
||||
case H::kI4:
|
||||
return ncclInt32;
|
||||
case H::kI8:
|
||||
return ncclInt64;
|
||||
case H::kU1:
|
||||
return ncclUint8;
|
||||
case H::kU2:
|
||||
return fatal();
|
||||
case H::kU4:
|
||||
return ncclUint32;
|
||||
case H::kU8:
|
||||
return ncclUint64;
|
||||
}
|
||||
return fatal();
|
||||
}
|
||||
|
||||
bool IsBitwiseOp(Op const& op) {
|
||||
return op == Op::kBitwiseAND || op == Op::kBitwiseOR || op == Op::kBitwiseXOR;
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
void RunBitwiseAllreduce(dh::CUDAStreamView stream, common::Span<std::int8_t> out_buffer,
|
||||
std::int8_t const* device_buffer, Func func, std::int32_t world_size,
|
||||
std::size_t size) {
|
||||
dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
|
||||
auto result = device_buffer[idx];
|
||||
for (auto rank = 1; rank < world_size; rank++) {
|
||||
result = func(result, device_buffer[rank * size + idx]);
|
||||
}
|
||||
out_buffer[idx] = result;
|
||||
});
|
||||
}
|
||||
|
||||
[[nodiscard]] Result BitwiseAllReduce(NCCLComm const* pcomm, ncclComm_t handle,
|
||||
common::Span<std::int8_t> data, Op op) {
|
||||
dh::device_vector<std::int8_t> buffer(data.size() * pcomm->World());
|
||||
auto* device_buffer = buffer.data().get();
|
||||
|
||||
// First gather data from all the workers.
|
||||
CHECK(handle);
|
||||
auto rc = GetNCCLResult(
|
||||
ncclAllGather(data.data(), device_buffer, data.size(), ncclInt8, handle, pcomm->Stream()));
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
// Then reduce locally.
|
||||
switch (op) {
|
||||
case Op::kBitwiseAND:
|
||||
RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_and<std::int8_t>(),
|
||||
pcomm->World(), data.size());
|
||||
break;
|
||||
case Op::kBitwiseOR:
|
||||
RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_or<std::int8_t>(),
|
||||
pcomm->World(), data.size());
|
||||
break;
|
||||
case Op::kBitwiseXOR:
|
||||
RunBitwiseAllreduce(pcomm->Stream(), data, device_buffer, thrust::bit_xor<std::int8_t>(),
|
||||
pcomm->World(), data.size());
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Not a bitwise reduce operation.";
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
ncclRedOp_t GetNCCLRedOp(Op const& op) {
|
||||
ncclRedOp_t result{ncclMax};
|
||||
switch (op) {
|
||||
case Op::kMax:
|
||||
result = ncclMax;
|
||||
break;
|
||||
case Op::kMin:
|
||||
result = ncclMin;
|
||||
break;
|
||||
case Op::kSum:
|
||||
result = ncclSum;
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported reduce operation.";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
[[nodiscard]] Result NCCLColl::Allreduce(Comm const& comm, common::Span<std::int8_t> data,
|
||||
ArrayInterfaceHandler::Type type, Op op) {
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
|
||||
CHECK(nccl);
|
||||
return Success() << [&] {
|
||||
if (IsBitwiseOp(op)) {
|
||||
return BitwiseAllReduce(nccl, nccl->Handle(), data, op);
|
||||
} else {
|
||||
return DispatchDType(type, [=](auto t) {
|
||||
using T = decltype(t);
|
||||
auto rdata = common::RestoreType<T>(data);
|
||||
auto rc = ncclAllReduce(data.data(), data.data(), rdata.size(), GetNCCLType(type),
|
||||
GetNCCLRedOp(op), nccl->Handle(), nccl->Stream());
|
||||
return GetNCCLResult(rc);
|
||||
});
|
||||
}
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
|
||||
[[nodiscard]] Result NCCLColl::Broadcast(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) {
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
|
||||
CHECK(nccl);
|
||||
return Success() << [&] {
|
||||
return GetNCCLResult(ncclBroadcast(data.data(), data.data(), data.size_bytes(), ncclInt8, root,
|
||||
nccl->Handle(), nccl->Stream()));
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
|
||||
[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) {
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
|
||||
CHECK(nccl);
|
||||
auto send = data.subspan(comm.Rank() * size, size);
|
||||
return Success() << [&] {
|
||||
return GetNCCLResult(
|
||||
ncclAllGather(send.data(), data.data(), size, ncclInt8, nccl->Handle(), nccl->Stream()));
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
|
||||
namespace cuda_impl {
|
||||
/**
|
||||
* @brief Implement allgather-v using broadcast.
|
||||
*
|
||||
* https://arxiv.org/abs/1812.05964
|
||||
*/
|
||||
Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes, common::Span<std::int8_t> recv) {
|
||||
return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] {
|
||||
std::size_t offset = 0;
|
||||
for (std::int32_t r = 0; r < comm->World(); ++r) {
|
||||
auto as_bytes = sizes[r];
|
||||
auto rc = ncclBroadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
|
||||
ncclInt8, r, comm->Handle(), dh::DefaultStream());
|
||||
if (rc != ncclSuccess) {
|
||||
return GetNCCLResult(rc);
|
||||
}
|
||||
offset += as_bytes;
|
||||
}
|
||||
return Success();
|
||||
} << [] { return GetNCCLResult(ncclGroupEnd()); };
|
||||
}
|
||||
} // namespace cuda_impl
|
||||
|
||||
[[nodiscard]] Result NCCLColl::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
|
||||
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
|
||||
CHECK(nccl);
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
|
||||
switch (algo) {
|
||||
case AllgatherVAlgo::kRing: {
|
||||
return Success() << [] { return GetNCCLResult(ncclGroupStart()); } << [&] {
|
||||
// get worker offset
|
||||
detail::AllgatherVOffset(sizes, recv_segments);
|
||||
// copy data
|
||||
auto current = recv.subspan(recv_segments[comm.Rank()], data.size_bytes());
|
||||
if (current.data() != data.data()) {
|
||||
dh::safe_cuda(cudaMemcpyAsync(current.data(), data.data(), current.size_bytes(),
|
||||
cudaMemcpyDeviceToDevice, nccl->Stream()));
|
||||
}
|
||||
return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
|
||||
} << [] {
|
||||
return GetNCCLResult(ncclGroupEnd());
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
case AllgatherVAlgo::kBcast: {
|
||||
return cuda_impl::BroadcastAllgatherV(nccl, data, sizes, recv);
|
||||
}
|
||||
default: {
|
||||
return Fail("Unknown algorithm for allgather-v");
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
#endif // defined(XGBOOST_USE_NCCL)
|
||||
29
src/collective/coll.cuh
Normal file
29
src/collective/coll.cuh
Normal file
@ -0,0 +1,29 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cstdint> // for int8_t, int64_t
|
||||
|
||||
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||
#include "coll.h" // for Coll
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
class NCCLColl : public Coll {
|
||||
public:
|
||||
~NCCLColl() override;
|
||||
|
||||
[[nodiscard]] Result Allreduce(Comm const& comm, common::Span<std::int8_t> data,
|
||||
ArrayInterfaceHandler::Type type, Op op) override;
|
||||
[[nodiscard]] Result Broadcast(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) override;
|
||||
[[nodiscard]] Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) override;
|
||||
[[nodiscard]] Result AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@ -2,17 +2,20 @@
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t, int64_t
|
||||
#include <memory> // for enable_shared_from_this
|
||||
|
||||
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
enum class AllgatherVAlgo {
|
||||
kRing = 0, // use ring-based allgather-v
|
||||
kBcast = 1, // use broadcast-based allgather-v
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Interface and base implementation for collective.
|
||||
*/
|
||||
@ -21,6 +24,8 @@ class Coll : public std::enable_shared_from_this<Coll> {
|
||||
Coll() = default;
|
||||
virtual ~Coll() noexcept(false) {} // NOLINT
|
||||
|
||||
Coll* MakeCUDAVar();
|
||||
|
||||
/**
|
||||
* @brief Allreduce
|
||||
*
|
||||
@ -29,8 +34,7 @@ class Coll : public std::enable_shared_from_this<Coll> {
|
||||
* @param [in] op Reduce operation. For custom operation, user needs to reach down to
|
||||
* the CPU implementation.
|
||||
*/
|
||||
[[nodiscard]] virtual Result Allreduce(Context const* ctx, Comm const& comm,
|
||||
common::Span<std::int8_t> data,
|
||||
[[nodiscard]] virtual Result Allreduce(Comm const& comm, common::Span<std::int8_t> data,
|
||||
ArrayInterfaceHandler::Type type, Op op);
|
||||
/**
|
||||
* @brief Broadcast
|
||||
@ -38,29 +42,29 @@ class Coll : public std::enable_shared_from_this<Coll> {
|
||||
* @param [in,out] data Data buffer for input and output.
|
||||
* @param [in] root Root rank for broadcast.
|
||||
*/
|
||||
[[nodiscard]] virtual Result Broadcast(Context const* ctx, Comm const& comm,
|
||||
common::Span<std::int8_t> data, std::int32_t root);
|
||||
[[nodiscard]] virtual Result Broadcast(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root);
|
||||
/**
|
||||
* @brief Allgather
|
||||
*
|
||||
* @param [in,out] data Data buffer for input and output.
|
||||
* @param [in] size Size of data for each worker.
|
||||
*/
|
||||
[[nodiscard]] virtual Result Allgather(Context const* ctx, Comm const& comm,
|
||||
common::Span<std::int8_t> data, std::size_t size);
|
||||
[[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size);
|
||||
/**
|
||||
* @brief Allgather with variable length.
|
||||
*
|
||||
* @param [in] data Input data for the current worker.
|
||||
* @param [in] sizes Size of the input from each worker.
|
||||
* @param [out] recv_segments pre-allocated offset for each worker in the output, size
|
||||
* should be equal to (world + 1).
|
||||
* @param [out] recv_segments pre-allocated offset buffer for each worker in the output,
|
||||
* size should be equal to (world + 1). GPU ring-based implementation
|
||||
* doesn't use the buffer.
|
||||
* @param [out] recv pre-allocated buffer for output.
|
||||
*/
|
||||
[[nodiscard]] virtual Result AllgatherV(Context const* ctx, Comm const& comm,
|
||||
common::Span<std::int8_t const> data,
|
||||
[[nodiscard]] virtual Result AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
common::Span<std::int8_t> recv);
|
||||
common::Span<std::int8_t> recv, AllgatherVAlgo algo);
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -262,7 +262,7 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
|
||||
}
|
||||
|
||||
RabitComm::~RabitComm() noexcept(false) {
|
||||
if (!IsDistributed()) {
|
||||
if (!this->IsDistributed()) {
|
||||
return;
|
||||
}
|
||||
auto rc = this->Shutdown();
|
||||
|
||||
112
src/collective/comm.cu
Normal file
112
src/collective/comm.cu
Normal file
@ -0,0 +1,112 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include <algorithm> // for sort
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for uint64_t, int8_t
|
||||
#include <cstring> // for memcpy
|
||||
#include <memory> // for shared_ptr
|
||||
#include <sstream> // for stringstream
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/device_helpers.cuh" // for DefaultStream
|
||||
#include "../common/type.h" // for EraseType
|
||||
#include "broadcast.h" // for Broadcast
|
||||
#include "comm.cuh" // for NCCLComm
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
Result GetUniqueId(Comm const& comm, ncclUniqueId* pid) {
|
||||
static const int kRootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (comm.Rank() == kRootRank) {
|
||||
dh::safe_nccl(ncclGetUniqueId(&id));
|
||||
}
|
||||
auto rc = Broadcast(comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)},
|
||||
kRootRank);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
*pid = id;
|
||||
return Success();
|
||||
}
|
||||
|
||||
inline constexpr std::size_t kUuidLength =
|
||||
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(std::uint64_t);
|
||||
|
||||
void GetCudaUUID(xgboost::common::Span<std::uint64_t, kUuidLength> const& uuid, DeviceOrd device) {
|
||||
cudaDeviceProp prob{};
|
||||
dh::safe_cuda(cudaGetDeviceProperties(&prob, device.ordinal));
|
||||
std::memcpy(uuid.data(), static_cast<void*>(&(prob.uuid)), sizeof(prob.uuid));
|
||||
}
|
||||
|
||||
static std::string PrintUUID(xgboost::common::Span<std::uint64_t, kUuidLength> const& uuid) {
|
||||
std::stringstream ss;
|
||||
for (auto v : uuid) {
|
||||
ss << std::hex << v;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) {
|
||||
return new NCCLComm{ctx, *this, pimpl};
|
||||
}
|
||||
|
||||
NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl)
|
||||
: Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(),
|
||||
root.TaskID()},
|
||||
stream_{dh::DefaultStream()} {
|
||||
this->world_ = root.World();
|
||||
this->rank_ = root.Rank();
|
||||
this->domain_ = root.Domain();
|
||||
if (!root.IsDistributed()) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
|
||||
|
||||
std::vector<std::uint64_t> uuids(root.World() * kUuidLength, 0);
|
||||
auto s_uuid = xgboost::common::Span<std::uint64_t>{uuids.data(), uuids.size()};
|
||||
auto s_this_uuid = s_uuid.subspan(root.Rank() * kUuidLength, kUuidLength);
|
||||
GetCudaUUID(s_this_uuid, ctx->Device());
|
||||
|
||||
auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes());
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
|
||||
std::vector<xgboost::common::Span<std::uint64_t, kUuidLength>> converted(root.World());
|
||||
std::size_t j = 0;
|
||||
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
|
||||
converted[j] = s_uuid.subspan(i, kUuidLength);
|
||||
j++;
|
||||
}
|
||||
|
||||
std::sort(converted.begin(), converted.end());
|
||||
auto iter = std::unique(converted.begin(), converted.end());
|
||||
auto n_uniques = std::distance(converted.begin(), iter);
|
||||
|
||||
CHECK_EQ(n_uniques, root.World())
|
||||
<< "Multiple processes within communication group running on same CUDA "
|
||||
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
||||
|
||||
rc = GetUniqueId(root, &nccl_unique_id_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank()));
|
||||
|
||||
for (std::int32_t r = 0; r < root.World(); ++r) {
|
||||
this->channels_.emplace_back(
|
||||
std::make_shared<NCCLChannel>(root, r, nccl_comm_, dh::DefaultStream()));
|
||||
}
|
||||
}
|
||||
|
||||
NCCLComm::~NCCLComm() {
|
||||
if (nccl_comm_) {
|
||||
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
#endif // defined(XGBOOST_USE_NCCL)
|
||||
67
src/collective/comm.cuh
Normal file
67
src/collective/comm.cuh
Normal file
@ -0,0 +1,67 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include "nccl.h"
|
||||
#endif // XGBOOST_USE_NCCL
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "coll.h"
|
||||
#include "comm.h"
|
||||
#include "xgboost/context.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
inline Result GetCUDAResult(cudaError rc) {
|
||||
if (rc == cudaSuccess) {
|
||||
return Success();
|
||||
}
|
||||
std::string msg = thrust::system_error(rc, thrust::cuda_category()).what();
|
||||
return Fail(msg);
|
||||
}
|
||||
|
||||
class NCCLComm : public Comm {
|
||||
ncclComm_t nccl_comm_{nullptr};
|
||||
ncclUniqueId nccl_unique_id_{};
|
||||
dh::CUDAStreamView stream_;
|
||||
|
||||
public:
|
||||
[[nodiscard]] ncclComm_t Handle() const { return nccl_comm_; }
|
||||
|
||||
explicit NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl);
|
||||
[[nodiscard]] Result LogTracker(std::string) const override {
|
||||
LOG(FATAL) << "Device comm is used for logging.";
|
||||
return Fail("Undefined.");
|
||||
}
|
||||
~NCCLComm() override;
|
||||
[[nodiscard]] bool IsFederated() const override { return false; }
|
||||
[[nodiscard]] dh::CUDAStreamView Stream() const { return stream_; }
|
||||
[[nodiscard]] Result Block() const override {
|
||||
auto rc = this->Stream().Sync(false);
|
||||
return GetCUDAResult(rc);
|
||||
}
|
||||
};
|
||||
|
||||
class NCCLChannel : public Channel {
|
||||
std::int32_t rank_{-1};
|
||||
ncclComm_t nccl_comm_{};
|
||||
dh::CUDAStreamView stream_;
|
||||
|
||||
public:
|
||||
explicit NCCLChannel(Comm const& comm, std::int32_t rank, ncclComm_t nccl_comm,
|
||||
dh::CUDAStreamView stream)
|
||||
: rank_{rank}, nccl_comm_{nccl_comm}, Channel{comm, nullptr}, stream_{stream} {}
|
||||
|
||||
void SendAll(std::int8_t const* ptr, std::size_t n) override {
|
||||
dh::safe_nccl(ncclSend(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
|
||||
}
|
||||
void RecvAll(std::int8_t* ptr, std::size_t n) override {
|
||||
dh::safe_nccl(ncclRecv(ptr, n, ncclInt8, rank_, nccl_comm_, stream_));
|
||||
}
|
||||
[[nodiscard]] Result Block() override {
|
||||
auto rc = stream_.Sync(false);
|
||||
return GetCUDAResult(rc);
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@ -8,7 +8,6 @@
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
#include <type_traits> // for remove_const_t
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
@ -16,6 +15,7 @@
|
||||
#include "protocol.h" // for PeerInfo
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
@ -35,13 +35,14 @@ inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) {
|
||||
}
|
||||
|
||||
class Channel;
|
||||
class Coll;
|
||||
|
||||
/**
|
||||
* @brief Base communicator storing info about the tracker and other communicators.
|
||||
*/
|
||||
class Comm {
|
||||
protected:
|
||||
std::int32_t world_{1};
|
||||
std::int32_t world_{-1};
|
||||
std::int32_t rank_{0};
|
||||
std::chrono::seconds timeout_{DefaultTimeoutSec()};
|
||||
std::int32_t retry_{DefaultRetry()};
|
||||
@ -69,12 +70,14 @@ class Comm {
|
||||
[[nodiscard]] Result ConnectTracker(TCPSocket* out) const;
|
||||
[[nodiscard]] auto Domain() const { return domain_; }
|
||||
[[nodiscard]] auto Timeout() const { return timeout_; }
|
||||
[[nodiscard]] auto Retry() const { return retry_; }
|
||||
[[nodiscard]] auto TaskID() const { return task_id_; }
|
||||
|
||||
[[nodiscard]] auto Rank() const { return rank_; }
|
||||
[[nodiscard]] auto World() const { return world_; }
|
||||
[[nodiscard]] bool IsDistributed() const { return World() > 1; }
|
||||
[[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; }
|
||||
[[nodiscard]] bool IsDistributed() const { return world_ != -1; }
|
||||
void Submit(Loop::Op op) const { loop_->Submit(op); }
|
||||
[[nodiscard]] Result Block() const { return loop_->Block(); }
|
||||
[[nodiscard]] virtual Result Block() const { return loop_->Block(); }
|
||||
|
||||
[[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const {
|
||||
return channels_.at(rank);
|
||||
@ -83,6 +86,8 @@ class Comm {
|
||||
[[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
|
||||
|
||||
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
|
||||
|
||||
Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl);
|
||||
};
|
||||
|
||||
class RabitComm : public Comm {
|
||||
@ -116,7 +121,7 @@ class Channel {
|
||||
explicit Channel(Comm const& comm, std::shared_ptr<TCPSocket> sock)
|
||||
: sock_{std::move(sock)}, comm_{comm} {}
|
||||
|
||||
void SendAll(std::int8_t const* ptr, std::size_t n) {
|
||||
virtual void SendAll(std::int8_t const* ptr, std::size_t n) {
|
||||
Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast<std::int8_t*>(ptr), n, sock_.get(), 0};
|
||||
CHECK(sock_.get());
|
||||
comm_.Submit(std::move(op));
|
||||
@ -125,7 +130,7 @@ class Channel {
|
||||
this->SendAll(data.data(), data.size_bytes());
|
||||
}
|
||||
|
||||
void RecvAll(std::int8_t* ptr, std::size_t n) {
|
||||
virtual void RecvAll(std::int8_t* ptr, std::size_t n) {
|
||||
Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0};
|
||||
CHECK(sock_.get());
|
||||
comm_.Submit(std::move(op));
|
||||
@ -133,7 +138,7 @@ class Channel {
|
||||
void RecvAll(common::Span<std::int8_t> data) { this->RecvAll(data.data(), data.size_bytes()); }
|
||||
|
||||
[[nodiscard]] auto Socket() const { return sock_; }
|
||||
[[nodiscard]] Result Block() { return comm_.Block(); }
|
||||
[[nodiscard]] virtual Result Block() { return comm_.Block(); }
|
||||
};
|
||||
|
||||
enum class Op { kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 };
|
||||
|
||||
@ -1169,7 +1169,13 @@ class CUDAStreamView {
|
||||
operator cudaStream_t() const { // NOLINT
|
||||
return stream_;
|
||||
}
|
||||
void Sync() { dh::safe_cuda(cudaStreamSynchronize(stream_)); }
|
||||
cudaError_t Sync(bool error = true) {
|
||||
if (error) {
|
||||
dh::safe_cuda(cudaStreamSynchronize(stream_));
|
||||
return cudaSuccess;
|
||||
}
|
||||
return cudaStreamSynchronize(stream_);
|
||||
}
|
||||
};
|
||||
|
||||
inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT
|
||||
|
||||
@ -19,7 +19,6 @@
|
||||
#include "../common/cuda_context.cuh" // CUDAContext
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/timer.h"
|
||||
#include "../data/ellpack_page.cuh"
|
||||
#include "../data/ellpack_page.h"
|
||||
@ -39,7 +38,6 @@
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/parameter.h"
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/task.h" // for ObjInfo
|
||||
#include "xgboost/tree_model.h"
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/collective/allgather.h" // for RingAllgather
|
||||
#include "../../../src/collective/coll.h" // for Coll
|
||||
#include "../../../src/collective/comm.h" // for RabitComm
|
||||
#include "gtest/gtest.h" // for AssertionR...
|
||||
#include "test_worker.h" // for TestDistri...
|
||||
@ -63,25 +64,7 @@ class Worker : public WorkerForTest {
|
||||
}
|
||||
}
|
||||
|
||||
void TestV() {
|
||||
{
|
||||
// basic test
|
||||
std::int32_t n{comm_.Rank()};
|
||||
std::vector<std::int32_t> result;
|
||||
auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
for (std::int32_t i = 0; i < comm_.World(); ++i) {
|
||||
ASSERT_EQ(result[i], i);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// V test
|
||||
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
|
||||
std::vector<std::int32_t> result;
|
||||
auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2);
|
||||
void CheckV(common::Span<std::int32_t> result) {
|
||||
std::int32_t k{0};
|
||||
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||
auto seg = common::Span{result.data(), result.size()}.subspan(k, (r + 1));
|
||||
@ -93,6 +76,66 @@ class Worker : public WorkerForTest {
|
||||
}
|
||||
}
|
||||
}
|
||||
void TestVRing() {
|
||||
// V test
|
||||
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
|
||||
std::vector<std::int32_t> result;
|
||||
auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2);
|
||||
CheckV(result);
|
||||
}
|
||||
|
||||
void TestVBasic() {
|
||||
// basic test
|
||||
std::int32_t n{comm_.Rank()};
|
||||
std::vector<std::int32_t> result;
|
||||
auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
for (std::int32_t i = 0; i < comm_.World(); ++i) {
|
||||
ASSERT_EQ(result[i], i);
|
||||
}
|
||||
}
|
||||
|
||||
void TestVAlgo() {
|
||||
// V test, broadcast
|
||||
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
|
||||
auto s_data = common::Span{data.data(), data.size()};
|
||||
|
||||
std::vector<std::int64_t> sizes(comm_.World(), 0);
|
||||
sizes[comm_.Rank()] = s_data.size_bytes();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
std::shared_ptr<Coll> pcoll{new Coll{}};
|
||||
|
||||
std::vector<std::int64_t> recv_segments(comm_.World() + 1, 0);
|
||||
std::vector<std::int32_t> recv(std::accumulate(sizes.cbegin(), sizes.cend(), 0));
|
||||
|
||||
auto s_recv = common::Span{recv.data(), recv.size()};
|
||||
|
||||
rc = pcoll->AllgatherV(comm_, common::EraseType(s_data),
|
||||
common::Span{sizes.data(), sizes.size()},
|
||||
common::Span{recv_segments.data(), recv_segments.size()},
|
||||
common::EraseType(s_recv), AllgatherVAlgo::kBcast);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
CheckV(s_recv);
|
||||
|
||||
// Test inplace
|
||||
auto test_inplace = [&] (AllgatherVAlgo algo) {
|
||||
std::fill_n(s_recv.data(), s_recv.size(), 0);
|
||||
auto current = s_recv.subspan(recv_segments[comm_.Rank()],
|
||||
recv_segments[comm_.Rank() + 1] - recv_segments[comm_.Rank()]);
|
||||
std::copy_n(data.data(), data.size(), current.data());
|
||||
rc = pcoll->AllgatherV(comm_, common::EraseType(current),
|
||||
common::Span{sizes.data(), sizes.size()},
|
||||
common::Span{recv_segments.data(), recv_segments.size()},
|
||||
common::EraseType(s_recv), algo);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
CheckV(s_recv);
|
||||
};
|
||||
|
||||
test_inplace(AllgatherVAlgo::kBcast);
|
||||
test_inplace(AllgatherVAlgo::kRing);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
@ -106,12 +149,30 @@ TEST_F(AllgatherTest, Basic) {
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(AllgatherTest, V) {
|
||||
TEST_F(AllgatherTest, VBasic) {
|
||||
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker worker{host, port, timeout, n_workers, r};
|
||||
worker.TestV();
|
||||
worker.TestVBasic();
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(AllgatherTest, VRing) {
|
||||
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker worker{host, port, timeout, n_workers, r};
|
||||
worker.TestVRing();
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(AllgatherTest, VAlgo) {
|
||||
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker worker{host, port, timeout, n_workers, r};
|
||||
worker.TestVAlgo();
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
117
tests/cpp/collective/test_allgather.cu
Normal file
117
tests/cpp/collective/test_allgather.cu
Normal file
@ -0,0 +1,117 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include <gtest/gtest.h>
|
||||
#include <thrust/device_vector.h> // for device_vector
|
||||
#include <thrust/equal.h> // for equal
|
||||
#include <xgboost/span.h> // for Span
|
||||
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t, int64_t
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/collective/allgather.h" // for RingAllgather
|
||||
#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector
|
||||
#include "../../../src/common/type.h" // for EraseType
|
||||
#include "test_worker.cuh" // for NCCLWorkerForTest
|
||||
#include "test_worker.h" // for TestDistributed, WorkerForTest
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class Worker : public NCCLWorkerForTest {
|
||||
public:
|
||||
using NCCLWorkerForTest::NCCLWorkerForTest;
|
||||
|
||||
void TestV(AllgatherVAlgo algo) {
|
||||
{
|
||||
// basic test
|
||||
std::size_t n = 1;
|
||||
// create data
|
||||
dh::device_vector<std::int32_t> data(n, comm_.Rank());
|
||||
auto s_data = common::EraseType(common::Span{data.data().get(), data.size()});
|
||||
// get size
|
||||
std::vector<std::int64_t> sizes(comm_.World(), -1);
|
||||
sizes[comm_.Rank()] = s_data.size_bytes();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
// create result
|
||||
dh::device_vector<std::int32_t> result(comm_.World(), -1);
|
||||
auto s_result = common::EraseType(dh::ToSpan(result));
|
||||
|
||||
std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
|
||||
rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
|
||||
common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
|
||||
for (std::int32_t i = 0; i < comm_.World(); ++i) {
|
||||
ASSERT_EQ(result[i], i);
|
||||
}
|
||||
}
|
||||
{
|
||||
// V test
|
||||
std::size_t n = 256 * 256;
|
||||
// create data
|
||||
dh::device_vector<std::int32_t> data(n * nccl_comm_->Rank(), nccl_comm_->Rank());
|
||||
auto s_data = common::EraseType(common::Span{data.data().get(), data.size()});
|
||||
// get size
|
||||
std::vector<std::int64_t> sizes(nccl_comm_->World(), 0);
|
||||
sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
|
||||
// create result
|
||||
dh::device_vector<std::int32_t> result(n_bytes / sizeof(std::int32_t), -1);
|
||||
auto s_result = common::EraseType(dh::ToSpan(result));
|
||||
|
||||
std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
|
||||
rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
|
||||
common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
// check segment size
|
||||
if (algo != AllgatherVAlgo::kBcast) {
|
||||
auto size = recv_seg[nccl_comm_->Rank() + 1] - recv_seg[nccl_comm_->Rank()];
|
||||
ASSERT_EQ(size, n * nccl_comm_->Rank() * sizeof(std::int32_t));
|
||||
ASSERT_EQ(size, sizes[nccl_comm_->Rank()]);
|
||||
}
|
||||
// check data
|
||||
std::size_t k{0};
|
||||
for (std::int32_t r = 0; r < nccl_comm_->World(); ++r) {
|
||||
std::size_t s = n * r;
|
||||
auto current = dh::ToSpan(result).subspan(k, s);
|
||||
std::vector<std::int32_t> h_data(current.size());
|
||||
dh::CopyDeviceSpanToVector(&h_data, current);
|
||||
for (auto v : h_data) {
|
||||
ASSERT_EQ(v, r);
|
||||
}
|
||||
k += s;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class AllgatherTestGPU : public SocketTest {};
|
||||
} // namespace
|
||||
|
||||
TEST_F(AllgatherTestGPU, MGPUTestVRing) {
|
||||
auto n_workers = common::AllVisibleGPUs();
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker w{host, port, timeout, n_workers, r};
|
||||
w.Setup();
|
||||
w.TestV(AllgatherVAlgo::kRing);
|
||||
w.TestV(AllgatherVAlgo::kBcast);
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(AllgatherTestGPU, MGPUTestVBcast) {
|
||||
auto n_workers = common::AllVisibleGPUs();
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker w{host, port, timeout, n_workers, r};
|
||||
w.Setup();
|
||||
w.TestV(AllgatherVAlgo::kBcast);
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
#endif // defined(XGBOOST_USE_NCCL)
|
||||
@ -6,10 +6,10 @@
|
||||
#include "../../../src/collective/allreduce.h"
|
||||
#include "../../../src/collective/coll.h" // for Coll
|
||||
#include "../../../src/collective/tracker.h"
|
||||
#include "../../../src/common/type.h" // for EraseType
|
||||
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
namespace {
|
||||
class AllreduceWorker : public WorkerForTest {
|
||||
public:
|
||||
@ -50,11 +50,10 @@ class AllreduceWorker : public WorkerForTest {
|
||||
}
|
||||
|
||||
void BitOr() {
|
||||
Context ctx;
|
||||
std::vector<std::uint32_t> data(comm_.World(), 0);
|
||||
data[comm_.Rank()] = ~std::uint32_t{0};
|
||||
auto pcoll = std::shared_ptr<Coll>{new Coll{}};
|
||||
auto rc = pcoll->Allreduce(&ctx, comm_, EraseType(common::Span{data.data(), data.size()}),
|
||||
auto rc = pcoll->Allreduce(comm_, common::EraseType(common::Span{data.data(), data.size()}),
|
||||
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
for (auto v : data) {
|
||||
|
||||
70
tests/cpp/collective/test_allreduce.cu
Normal file
70
tests/cpp/collective/test_allreduce.cu
Normal file
@ -0,0 +1,70 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include <gtest/gtest.h>
|
||||
#include <thrust/host_vector.h> // for host_vector
|
||||
|
||||
#include "../../../src/collective/coll.h" // for Coll
|
||||
#include "../../../src/common/common.h"
|
||||
#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector
|
||||
#include "../../../src/common/type.h" // for EraseType
|
||||
#include "../helpers.h" // for MakeCUDACtx
|
||||
#include "test_worker.cuh" // for NCCLWorkerForTest
|
||||
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class AllreduceTestGPU : public SocketTest {};
|
||||
|
||||
class Worker : public NCCLWorkerForTest {
|
||||
public:
|
||||
using NCCLWorkerForTest::NCCLWorkerForTest;
|
||||
|
||||
void BitOr() {
|
||||
dh::device_vector<std::uint32_t> data(comm_.World(), 0);
|
||||
data[comm_.Rank()] = ~std::uint32_t{0};
|
||||
auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
|
||||
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
thrust::host_vector<std::uint32_t> h_data(data.size());
|
||||
thrust::copy(data.cbegin(), data.cend(), h_data.begin());
|
||||
for (auto v : h_data) {
|
||||
ASSERT_EQ(v, ~std::uint32_t{0});
|
||||
}
|
||||
}
|
||||
|
||||
void Acc() {
|
||||
dh::device_vector<double> data(314, 1.5);
|
||||
auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
|
||||
ArrayInterfaceHandler::kF8, Op::kSum);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
for (std::size_t i = 0; i < data.size(); ++i) {
|
||||
auto v = data[i];
|
||||
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST_F(AllreduceTestGPU, BitOr) {
|
||||
auto n_workers = common::AllVisibleGPUs();
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker w{host, port, timeout, n_workers, r};
|
||||
w.Setup();
|
||||
w.BitOr();
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(AllreduceTestGPU, Sum) {
|
||||
auto n_workers = common::AllVisibleGPUs();
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker w{host, port, timeout, n_workers, r};
|
||||
w.Setup();
|
||||
w.Acc();
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
#endif // defined(XGBOOST_USE_NCCL)
|
||||
@ -47,5 +47,5 @@ TEST_F(BroadcastTest, Basic) {
|
||||
Worker worker{host, port, timeout, n_workers, r};
|
||||
worker.Run();
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
} // namespace xgboost::collective
|
||||
|
||||
32
tests/cpp/collective/test_worker.cuh
Normal file
32
tests/cpp/collective/test_worker.cuh
Normal file
@ -0,0 +1,32 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <memory> // for shared_ptr
|
||||
|
||||
#include "../../../src/collective/coll.h" // for Coll
|
||||
#include "../../../src/collective/comm.h" // for Comm
|
||||
#include "test_worker.h"
|
||||
#include "xgboost/context.h" // for Context
|
||||
|
||||
namespace xgboost::collective {
|
||||
class NCCLWorkerForTest : public WorkerForTest {
|
||||
protected:
|
||||
std::shared_ptr<Coll> coll_;
|
||||
std::shared_ptr<xgboost::collective::Comm> nccl_comm_;
|
||||
std::shared_ptr<Coll> nccl_coll_;
|
||||
Context ctx_;
|
||||
|
||||
public:
|
||||
using WorkerForTest::WorkerForTest;
|
||||
|
||||
void Setup() {
|
||||
ctx_ = MakeCUDACtx(comm_.Rank());
|
||||
coll_.reset(new Coll{});
|
||||
nccl_comm_.reset(this->comm_.MakeCUDAVar(&ctx_, coll_));
|
||||
nccl_coll_.reset(coll_->MakeCUDAVar());
|
||||
ASSERT_EQ(comm_.World(), nccl_comm_->World());
|
||||
ASSERT_EQ(comm_.Rank(), nccl_comm_->Rank());
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@ -1,6 +1,7 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <chrono> // for seconds
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user