[coll] Add global functions. (#10203)

This commit is contained in:
Jiaming Yuan 2024-04-19 03:17:23 +08:00 committed by GitHub
parent 551fa6e25e
commit 3f64b4fde3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 283 additions and 69 deletions

View File

@ -104,6 +104,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/comm_group.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \

View File

@ -104,6 +104,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/comm_group.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \

View File

@ -165,7 +165,7 @@ template <typename T>
T GlobalRatio(Context const* ctx, MetaInfo const& info, T dividend, T divisor) {
std::array<T, 2> results{dividend, divisor};
auto rc = GlobalSum(ctx, info, linalg::MakeVec(results.data(), results.size()));
collective::SafeColl(rc);
SafeColl(rc);
std::tie(dividend, divisor) = std::tuple_cat(results);
if (divisor <= 0) {
return std::numeric_limits<T>::quiet_NaN();

View File

@ -33,6 +33,7 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
bool is_last_segment = send_rank == (world - 1);
auto send_nbytes = is_last_segment ? (data.size_bytes() - send_off) : segment_size;
auto send_seg = data.subspan(send_off, send_nbytes);
CHECK_NE(send_seg.size(), 0);
return next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
} << [&] {
auto recv_rank = (rank + world - r - 1 + worker_off) % world;
@ -40,9 +41,10 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
bool is_last_segment = recv_rank == (world - 1);
auto recv_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : segment_size;
auto recv_seg = data.subspan(recv_off, recv_nbytes);
CHECK_NE(recv_seg.size(), 0);
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
} << [&] {
return prev_ch->Block();
return comm.Block();
};
if (!rc.OK()) {
return rc;
@ -106,4 +108,47 @@ namespace detail {
return comm.Block();
}
} // namespace detail
/**
 * @brief Gather a list of variable-length byte vectors from every worker in the group.
 *
 * The exchange is done in two rounds: first the length of every local vector is gathered,
 * then the flattened payload. The gathered payload is split back into individual vectors
 * using the gathered lengths.
 *
 * @param ctx   Context forwarded to the underlying AllgatherV calls.
 * @param comm  The communicator group used for the exchange.
 * @param input Local vectors. Both the number of vectors and the length of each vector may
 *              differ between workers.
 *
 * @return One vector per gathered input, in the order their bytes appear in the gathered
 *         buffer. When the group is not distributed, a copy of the local input.
 */
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
    Context const* ctx, CommGroup const& comm, std::vector<std::vector<char>> const& input) {
  // AllgatherV returns early and leaves the receive buffer untouched when the group is not
  // distributed. Without this guard a lone worker would get an empty result back instead
  // of its own data.
  if (!comm.IsDistributed()) {
    return input;
  }

  // Round 1: exchange the length of every local vector.
  std::vector<std::int64_t> sizes(input.size());
  std::transform(input.cbegin(), input.cend(), sizes.begin(),
                 [](auto const& vec) { return static_cast<std::int64_t>(vec.size()); });

  std::vector<std::int64_t> recv_segments(comm.World() + 1, 0);
  HostDeviceVector<std::int8_t> recv;
  auto rc =
      AllgatherV(ctx, comm, linalg::MakeVec(sizes.data(), sizes.size()), &recv_segments, &recv);
  SafeColl(rc);

  // Turn the gathered lengths into offsets into the (yet to be gathered) payload. This has
  // to be done before `recv` is reused for the second round.
  auto global_sizes = common::RestoreType<std::int64_t const>(recv.ConstHostSpan());
  std::vector<std::int64_t> offset(global_sizes.size() + 1);
  offset[0] = 0;
  for (std::size_t i = 1; i < offset.size(); i++) {
    offset[i] = offset[i - 1] + global_sizes[i - 1];
  }

  // Round 2: flatten the local vectors into one buffer and exchange the payload.
  std::size_t n_local_bytes = 0;
  for (auto const& vec : input) {
    n_local_bytes += vec.size();
  }
  std::vector<char> collected;
  collected.reserve(n_local_bytes);
  for (auto const& vec : input) {
    collected.insert(collected.end(), vec.cbegin(), vec.cend());
  }
  rc = AllgatherV(ctx, comm, linalg::MakeVec(collected.data(), collected.size()), &recv_segments,
                  &recv);
  SafeColl(rc);
  auto out = common::RestoreType<char const>(recv.ConstHostSpan());

  // Split the gathered payload back into individual vectors.
  std::vector<std::vector<char>> result;
  result.reserve(offset.size() - 1);
  for (std::size_t i = 1; i < offset.size(); ++i) {
    result.emplace_back(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]);
  }
  return result;
}
/**
 * @brief Convenience overload of VectorAllgatherV that uses the global communicator group.
 *
 * @param ctx   Context forwarded to the group overload.
 * @param input Local vectors to be gathered.
 */
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
    Context const* ctx, std::vector<std::vector<char>> const& input) {
  auto const& global_group = *GlobalCommGroup();
  return VectorAllgatherV(ctx, global_group, input);
}
} // namespace xgboost::collective

View File

@ -102,4 +102,115 @@ template <typename T>
return detail::RingAllgatherV(comm, sizes, s_segments, erased_result);
}
/**
 * @brief Allgather over an explicit communicator group.
 *
 * Returns success without touching `data` when the group is not distributed. Otherwise the
 * (contiguous) view is type-erased and handed to the backend that matches the device of
 * `data`.
 *
 * @param ctx  Context used to obtain the communicator for the device of `data`.
 * @param comm The communicator group.
 * @param data The input and output buffer.
 */
template <typename T>
[[nodiscard]] Result Allgather(Context const* ctx, CommGroup const& comm,
                               linalg::VectorView<T> data) {
  if (comm.IsDistributed()) {
    CHECK(data.Contiguous());
    auto bytes = common::EraseType(data.Values());
    auto const& comm_ctx = comm.Ctx(ctx, data.Device());
    return comm.Backend(data.Device())->Allgather(comm_ctx, bytes);
  }
  return Success();
}
/**
 * @brief Gather data from all workers using the global communicator group.
 *
 * @param ctx  Context forwarded to the group overload.
 * @param data The input and output buffer, needs to be pre-allocated by the caller. Its
 *             number of elements must be a multiple of the number of workers.
 *
 * @return A failure when the element count is not divisible by the number of workers,
 *         otherwise the result of the group overload.
 */
template <typename T>
[[nodiscard]] Result Allgather(Context const* ctx, linalg::VectorView<T> data) {
  auto const& group = *GlobalCommGroup();
  bool const evenly_divisible = (data.Size() % group.World()) == 0;
  if (evenly_divisible) {
    return Allgather(ctx, group, data);
  }
  return Fail("The total number of elements should be multiple of the number of workers.");
}
/**
 * @brief Allgather with variable length data over an explicit communicator group.
 *
 * When the group is not distributed this returns success immediately and leaves both
 * `recv_segments` and `recv` unmodified.
 *
 * @param ctx           Context used to obtain the communicators.
 * @param comm          The communicator group.
 * @param data          The local input data.
 * @param recv_segments Output, resized to (number of workers + 1) and filled by
 *                      detail::AllgatherVOffset from the gathered sizes.
 * @param recv          Output buffer, placed on the device of `data` and resized to the
 *                      sum of the gathered sizes in bytes.
 */
template <typename T>
[[nodiscard]] Result AllgatherV(Context const* ctx, CommGroup const& comm,
linalg::VectorView<T> data,
std::vector<std::int64_t>* recv_segments,
HostDeviceVector<std::int8_t>* recv) {
if (!comm.IsDistributed()) {
return Success();
}
// Exchange the local size in bytes. Only the slot of this rank is set before the call,
// and the exchange always goes through the CPU backend regardless of where `data` lives.
std::vector<std::int64_t> sizes(comm.World(), 0);
sizes[comm.Rank()] = data.Values().size_bytes();
auto erased_sizes = common::EraseType(common::Span{sizes.data(), sizes.size()});
auto rc = comm.Backend(DeviceOrd::CPU())
->Allgather(comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes);
if (!rc.OK()) {
return rc;
}
// NOTE(review): AllgatherVOffset presumably writes the prefix sum of `sizes`, which would
// make the segments byte offsets into `recv` -- confirm against its definition.
recv_segments->resize(sizes.size() + 1);
detail::AllgatherVOffset(sizes, common::Span{recv_segments->data(), recv_segments->size()});
// 0LL keeps the accumulation in a 64-bit integer.
auto total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0LL);
recv->SetDevice(data.Device());
recv->Resize(total_bytes);
auto s_segments = common::Span{recv_segments->data(), recv_segments->size()};
// The payload goes through the backend matching the device of `data`; the receive span is
// taken from the device side for CUDA inputs and from the host side otherwise.
auto backend = comm.Backend(data.Device());
auto erased = common::EraseType(data.Values());
return backend->AllgatherV(
comm.Ctx(ctx, data.Device()), erased, common::Span{sizes.data(), sizes.size()}, s_segments,
data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(), AllgatherVAlgo::kBcast);
}
/**
 * @brief Allgather with variable length data, using the global communicator group.
 *
 * @param ctx           Context forwarded to the group overload.
 * @param data          The input data.
 * @param recv_segments Segment boundaries for each worker. [0, 2, 5] means the range [0, 2)
 *                      comes from the first worker and [2, 5) from the second one.
 * @param recv          The buffer storing the result.
 */
template <typename T>
[[nodiscard]] Result AllgatherV(Context const* ctx, linalg::VectorView<T> data,
                                std::vector<std::int64_t>* recv_segments,
                                HostDeviceVector<std::int8_t>* recv) {
  auto const& global_group = *GlobalCommGroup();
  return AllgatherV(ctx, global_group, data, recv_segments, recv);
}
/**
 * @brief Gathers variable-length data from all workers of the given communicator group.
 *
 * @param ctx   Context used for the underlying collective calls.
 * @param comm  The communicator group to gather over.
 * @param input All the inputs from the local worker. Both the number of inputs and the size
 *              of each input can vary across workers.
 *
 * @return The gathered vectors.
 */
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, CommGroup const& comm, std::vector<std::vector<char>> const& input);
/**
 * @brief Gathers variable-length data from all processes and distributes it to all processes.
 *
 * @param ctx   Context used for the underlying collective calls.
 * @param input All the inputs from the local worker. The number of inputs can vary
 *              across different workers, and so can the size of each vector in the
 *              input.
 *
 * @return The AllgatherV result, containing vectors from all workers.
 */
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, std::vector<std::vector<char>> const& input);
/**
 * @brief Gathers variable-length strings from all processes and distributes them to all
 *        processes, using the global communicator group and a default-constructed context.
 *
 * @param input    Variable-length list of variable-length strings from the local worker.
 * @param p_result Output, overwritten with the gathered strings.
 *
 * @return This function itself always returns success.
 */
[[nodiscard]] inline Result AllgatherStrings(std::vector<std::string> const& input,
                                             std::vector<std::string>* p_result) {
  // Re-encode every string as a byte vector, which is what VectorAllgatherV exchanges.
  std::vector<std::vector<char>> as_bytes;
  as_bytes.reserve(input.size());
  for (auto const& str : input) {
    as_bytes.emplace_back(str.cbegin(), str.cend());
  }

  Context ctx;
  auto gathered = VectorAllgatherV(&ctx, *GlobalCommGroup(), as_bytes);

  // Convert the gathered byte vectors back into strings.
  std::vector<std::string>& strings = *p_result;
  strings.clear();
  strings.reserve(gathered.size());
  for (auto const& bytes : gathered) {
    strings.emplace_back(bytes.cbegin(), bytes.cend());
  }
  return Success();
}
} // namespace xgboost::collective

View File

@ -68,6 +68,8 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
auto s_buf = common::Span{buffer.data(), buffer.size()};
for (std::int32_t r = 0; r < world - 1; ++r) {
common::Span<std::int8_t> seg, recv_seg;
auto rc = Success() << [&] {
// send to ring next
auto send_rank = (rank + world - r) % world;
auto send_off = send_rank * n_bytes_in_seg;
@ -75,32 +77,26 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
bool is_last_segment = send_rank == (world - 1);
auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg;
CHECK_EQ(seg_nbytes % sizeof(T), 0);
auto send_seg = data.subspan(send_off, seg_nbytes);
auto rc = next_ch->SendAll(send_seg);
if (!rc.OK()) {
return rc;
}
return next_ch->SendAll(send_seg);
} << [&] {
// receive from ring prev
auto recv_rank = (rank + world - r - 1) % world;
auto recv_off = recv_rank * n_bytes_in_seg;
is_last_segment = recv_rank == (world - 1);
bool is_last_segment = recv_rank == (world - 1);
seg_nbytes = is_last_segment ? data.size_bytes() - recv_off : n_bytes_in_seg;
auto seg_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : n_bytes_in_seg;
CHECK_EQ(seg_nbytes % sizeof(T), 0);
auto recv_seg = data.subspan(recv_off, seg_nbytes);
auto seg = s_buf.subspan(0, recv_seg.size());
rc = std::move(rc) << [&] {
recv_seg = data.subspan(recv_off, seg_nbytes);
seg = s_buf.subspan(0, recv_seg.size());
return prev_ch->RecvAll(seg);
} << [&] {
return comm.Block();
};
if (!rc.OK()) {
return rc;
}
// accumulate to recv_seg
CHECK_EQ(seg.size(), recv_seg.size());

View File

@ -1,15 +1,18 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int8_t
#include <functional> // for function
#include <type_traits> // for is_invocable_v, enable_if_t
#include <vector> // for vector
#include "../common/type.h" // for EraseType, RestoreType
#include "../data/array_interface.h" // for ArrayInterfaceHandler
#include "../data/array_interface.h" // for ToDType, ArrayInterfaceHandler
#include "comm.h" // for Comm, RestoreType
#include "comm_group.h" // for GlobalCommGroup
#include "xgboost/collective/result.h" // for Result
#include "xgboost/context.h" // for Context
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
@ -27,8 +30,7 @@ std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>
auto erased = common::EraseType(data);
auto type = ToDType<T>::kType;
auto erased_fn = [type, redop](common::Span<std::int8_t const> lhs,
common::Span<std::int8_t> out) {
auto erased_fn = [redop](common::Span<std::int8_t const> lhs, common::Span<std::int8_t> out) {
CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
auto lhs_t = common::RestoreType<T const>(lhs);
auto rhs_t = common::RestoreType<T>(out);
@ -37,4 +39,40 @@ std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>
return cpu_impl::RingAllreduce(comm, erased, erased_fn, type);
}
/**
 * @brief In-place allreduce of a tensor view over an explicit communicator group.
 *
 * Returns success without touching `data` when the group is not distributed. Otherwise the
 * (contiguous) view is type-erased and reduced by the backend that matches the device of
 * `data`.
 *
 * @param ctx  Context used to obtain the communicator for the device of `data`.
 * @param comm The communicator group.
 * @param data The input and output buffer.
 * @param op   The reduction operation.
 */
template <typename T, std::int32_t kDim>
[[nodiscard]] Result Allreduce(Context const* ctx, CommGroup const& comm,
                               linalg::TensorView<T, kDim> data, Op op) {
  if (comm.IsDistributed()) {
    CHECK(data.Contiguous());
    auto bytes = common::EraseType(data.Values());
    auto const dtype = ToDType<T>::kType;
    auto reducer = comm.Backend(data.Device());
    auto const& comm_ctx = comm.Ctx(ctx, data.Device());
    return reducer->Allreduce(comm_ctx, bytes, dtype, op);
  }
  return Success();
}
/**
 * @brief In-place allreduce of a tensor view using the global communicator group.
 *
 * @param ctx  Context forwarded to the group overload.
 * @param data The input and output buffer.
 * @param op   The reduction operation.
 */
template <typename T, std::int32_t kDim>
[[nodiscard]] Result Allreduce(Context const* ctx, linalg::TensorView<T, kDim> data, Op op) {
  auto const& global_group = *GlobalCommGroup();
  return Allreduce(ctx, global_group, data, op);
}
/**
 * @brief Overload for std::vector. The vector is reduced in place through a vector view
 *        over its storage.
 *
 * @param ctx  Context forwarded to the tensor view overload.
 * @param data Pointer to the vector holding the input, which also receives the output.
 * @param op   The reduction operation.
 */
template <typename T, typename Alloc>
[[nodiscard]] Result Allreduce(Context const* ctx, std::vector<T, Alloc>* data, Op op) {
  auto storage_view = linalg::MakeVec(data->data(), data->size());
  return Allreduce(ctx, storage_view, op);
}
/**
 * @brief Overload for a single scalar value of a trivial, standard-layout type. The value
 *        is reduced in place through a vector view of length one.
 *
 * @param ctx  Context forwarded to the tensor view overload.
 * @param data Pointer to the scalar holding the input, which also receives the output.
 * @param op   The reduction operation.
 */
template <typename T>
[[nodiscard]] std::enable_if_t<std::is_standard_layout_v<T> && std::is_trivial_v<T>, Result>
Allreduce(Context const* ctx, T* data, Op op) {
  auto scalar_view = linalg::MakeVec(data, 1);
  return Allreduce(ctx, scalar_view, op);
}
} // namespace xgboost::collective

View File

@ -1,11 +1,15 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int32_t, int8_t
#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for
#include "../common/type.h"
#include "comm.h" // for Comm, EraseType
#include "comm_group.h" // for CommGroup
#include "xgboost/collective/result.h" // for Result
#include "xgboost/context.h" // for Context
#include "xgboost/linalg.h" // for VectorView
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
@ -23,4 +27,21 @@ template <typename T>
common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes};
return cpu_impl::Broadcast(comm, erased, root);
}
/**
 * @brief Broadcast a vector view from the root worker over an explicit communicator group.
 *
 * Returns success without touching `data` when the group is not distributed. Otherwise the
 * (contiguous) view is type-erased and handed to the backend that matches the device of
 * `data`.
 *
 * @param ctx  Context used to obtain the communicator for the device of `data`.
 * @param comm The communicator group.
 * @param data The input and output buffer.
 * @param root Rank of the root worker, forwarded to the backend.
 */
template <typename T>
[[nodiscard]] Result Broadcast(Context const* ctx, CommGroup const& comm,
                               linalg::VectorView<T> data, std::int32_t root) {
  if (comm.IsDistributed()) {
    CHECK(data.Contiguous());
    auto bytes = common::EraseType(data.Values());
    auto sender = comm.Backend(data.Device());
    auto const& comm_ctx = comm.Ctx(ctx, data.Device());
    return sender->Broadcast(comm_ctx, bytes, root);
  }
  return Success();
}
/**
 * @brief Broadcast a vector view from the root worker using the global communicator group.
 *
 * @param ctx  Context forwarded to the group overload.
 * @param data The input and output buffer.
 * @param root Rank of the root worker.
 */
template <typename T>
[[nodiscard]] Result Broadcast(Context const* ctx, linalg::VectorView<T> data, std::int32_t root) {
  auto const& global_group = *GlobalCommGroup();
  return Broadcast(ctx, global_group, data, root);
}
} // namespace xgboost::collective

View File

@ -27,7 +27,7 @@ Result GetUniqueId(Comm const& comm, std::shared_ptr<NcclStub> stub, std::shared
ncclUniqueId id;
if (comm.Rank() == kRootRank) {
auto rc = stub->GetUniqueId(&id);
CHECK(rc.OK()) << rc.Report();
SafeColl(rc);
}
auto rc = coll->Broadcast(
comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)}, kRootRank);
@ -81,8 +81,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
GetCudaUUID(s_this_uuid, ctx->Device());
auto rc = pimpl->Allgather(root, common::EraseType(s_uuid));
CHECK(rc.OK()) << rc.Report();
SafeColl(rc);
std::vector<xgboost::common::Span<std::uint64_t, kUuidLength>> converted(root.World());
std::size_t j = 0;
@ -103,7 +102,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
[&] {
return this->stub_->CommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank());
};
CHECK(rc.OK()) << rc.Report();
SafeColl(rc);
for (std::int32_t r = 0; r < root.World(); ++r) {
this->channels_.emplace_back(
@ -114,7 +113,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
NCCLComm::~NCCLComm() {
if (nccl_comm_) {
auto rc = stub_->CommDestroy(nccl_comm_);
CHECK(rc.OK()) << rc.Report();
SafeColl(rc);
}
}
} // namespace xgboost::collective

View File

@ -12,7 +12,6 @@
#include <thrust/iterator/transform_output_iterator.h> // make_transform_output_iterator
#include <thrust/logical.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#include <thrust/transform_scan.h>

View File

@ -8,6 +8,7 @@
#define COMMON_HIST_UTIL_CUH_
#include <thrust/host_vector.h>
#include <thrust/sort.h> // for sort
#include <cstddef> // for size_t

View File

@ -1,13 +1,13 @@
/**
* Copyright 2019-2023 by XGBoost Contributors
* Copyright 2019-2024, XGBoost Contributors
*/
#include <thrust/functional.h>
#include <thrust/random.h>
#include <thrust/sort.h> // for sort
#include <thrust/transform.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/logging.h>
#include <algorithm>
#include <cstddef> // for size_t
#include <limits>
#include <utility>

View File

@ -71,7 +71,7 @@ target_include_directories(testxgboost
${xgboost_SOURCE_DIR}/rabit/include)
target_link_libraries(testxgboost
PRIVATE
${GTEST_LIBRARIES})
GTest::gtest GTest::gmock)
set_output_directory(testxgboost ${xgboost_BINARY_DIR})

View File

@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include <gtest/gtest.h> // for ASSERT_EQ
#include <xgboost/span.h> // for Span, oper...
@ -35,7 +35,7 @@ class Worker : public WorkerForTest {
data[comm_.Rank()] = comm_.Rank();
auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()});
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
for (std::int32_t r = 0; r < comm_.World(); ++r) {
ASSERT_EQ(data[r], r);
@ -52,7 +52,7 @@ class Worker : public WorkerForTest {
std::iota(seg.begin(), seg.end(), comm_.Rank());
auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()});
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
for (std::int32_t r = 0; r < comm_.World(); ++r) {
auto seg = s_data.subspan(r * n, n);
@ -81,7 +81,7 @@ class Worker : public WorkerForTest {
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
std::vector<std::int32_t> result;
auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2);
CheckV(result);
}
@ -91,7 +91,7 @@ class Worker : public WorkerForTest {
std::int32_t n{comm_.Rank()};
std::vector<std::int32_t> result;
auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
for (std::int32_t i = 0; i < comm_.World(); ++i) {
ASSERT_EQ(result[i], i);
}
@ -105,7 +105,7 @@ class Worker : public WorkerForTest {
std::vector<std::int64_t> sizes(comm_.World(), 0);
sizes[comm_.Rank()] = s_data.size_bytes();
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
std::shared_ptr<Coll> pcoll{new Coll{}};
std::vector<std::int64_t> recv_segments(comm_.World() + 1, 0);

View File

@ -34,7 +34,7 @@ class Worker : public NCCLWorkerForTest {
std::vector<std::int64_t> sizes(comm_.World(), -1);
sizes[comm_.Rank()] = s_data.size_bytes();
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
// create result
dh::device_vector<std::int32_t> result(comm_.World(), -1);
auto s_result = common::EraseType(dh::ToSpan(result));
@ -42,7 +42,7 @@ class Worker : public NCCLWorkerForTest {
std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
for (std::int32_t i = 0; i < comm_.World(); ++i) {
ASSERT_EQ(result[i], i);
@ -58,7 +58,7 @@ class Worker : public NCCLWorkerForTest {
std::vector<std::int64_t> sizes(nccl_comm_->World(), 0);
sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes();
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
// create result
dh::device_vector<std::int32_t> result(n_bytes / sizeof(std::int32_t), -1);
@ -67,7 +67,7 @@ class Worker : public NCCLWorkerForTest {
std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
// check segment size
if (algo != AllgatherVAlgo::kBcast) {
auto size = recv_seg[nccl_comm_->Rank() + 1] - recv_seg[nccl_comm_->Rank()];

View File

@ -59,7 +59,7 @@ class AllreduceWorker : public WorkerForTest {
auto pcoll = std::shared_ptr<Coll>{new Coll{}};
auto rc = pcoll->Allreduce(comm_, common::EraseType(common::Span{data.data(), data.size()}),
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
for (auto v : data) {
ASSERT_EQ(v, ~std::uint32_t{0});
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#if defined(XGBOOST_USE_NCCL)
#include <gtest/gtest.h>
@ -24,7 +24,7 @@ class Worker : public NCCLWorkerForTest {
data[comm_.Rank()] = ~std::uint32_t{0};
auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
thrust::host_vector<std::uint32_t> h_data(data.size());
thrust::copy(data.cbegin(), data.cend(), h_data.begin());
for (auto v : h_data) {
@ -36,7 +36,7 @@ class Worker : public NCCLWorkerForTest {
dh::device_vector<double> data(314, 1.5);
auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
ArrayInterfaceHandler::kF8, Op::kSum);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
for (std::size_t i = 0; i < data.size(); ++i) {
auto v = data[i];
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/collective/socket.h>
@ -10,7 +10,6 @@
#include <vector> // for vector
#include "../../../src/collective/broadcast.h" // for Broadcast
#include "../../../src/collective/tracker.h" // for GetHostAddress
#include "test_worker.h" // for WorkerForTest, TestDistributed
namespace xgboost::collective {
@ -24,14 +23,14 @@ class Worker : public WorkerForTest {
// basic test
std::vector<std::int32_t> data(1, comm_.Rank());
auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
ASSERT_EQ(data[0], r);
}
for (std::int32_t r = 0; r < comm_.World(); ++r) {
std::vector<std::int32_t> data(1 << 16, comm_.Rank());
auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
ASSERT_EQ(data[0], r);
}
}
@ -41,11 +40,11 @@ class BroadcastTest : public SocketTest {};
} // namespace
TEST_F(BroadcastTest, Basic) {
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
std::int32_t n_workers = std::min(2u, std::thread::hardware_concurrency());
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
Worker worker{host, port, timeout, n_workers, r};
worker.Run();
});
} // namespace
}
} // namespace xgboost::collective

View File

@ -1,14 +1,16 @@
/*!
* Copyright 2017-2021 XGBoost contributors
/**
* Copyright 2017-2024, XGBoost contributors
*/
#include <thrust/device_vector.h>
#include <thrust/sort.h> // for is_sorted
#include <xgboost/base.h>
#include <cstddef>
#include <cstdint>
#include <thrust/device_vector.h>
#include <vector>
#include <xgboost/base.h>
#include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/quantile.h"
#include "../helpers.h"
#include "gtest/gtest.h"
TEST(SumReduce, Test) {

View File

@ -1,10 +1,11 @@
/**
* Copyright 2019-2023, XGBoost Contributors
* Copyright 2019-2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <cstddef> // for size_t
#include <fstream> // for ofstream
#include <numeric> // for iota
#include "../../../src/common/io.h"
#include "../filesystem.h" // dmlc::TemporaryDirectory

View File

@ -4,10 +4,10 @@
#include <gtest/gtest.h>
#include <fstream>
#include <iterator> // for back_inserter
#include <limits> // for numeric_limits
#include <map>
#include <numeric> // for iota
#include "../../../src/common/charconv.h"
#include "../../../src/common/io.h"
#include "../../../src/common/json_utils.h"
#include "../../../src/common/threading_utils.h" // for ParallelFor