[coll] Add global functions. (#10203)

This commit is contained in:
Jiaming Yuan
2024-04-19 03:17:23 +08:00
committed by GitHub
parent 551fa6e25e
commit 3f64b4fde3
21 changed files with 283 additions and 69 deletions

View File

@@ -71,7 +71,7 @@ target_include_directories(testxgboost
${xgboost_SOURCE_DIR}/rabit/include)
target_link_libraries(testxgboost
PRIVATE
${GTEST_LIBRARIES})
GTest::gtest GTest::gmock)
set_output_directory(testxgboost ${xgboost_BINARY_DIR})

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include <gtest/gtest.h> // for ASSERT_EQ
#include <xgboost/span.h> // for Span, oper...
@@ -35,7 +35,7 @@ class Worker : public WorkerForTest {
data[comm_.Rank()] = comm_.Rank();
auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()});
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
for (std::int32_t r = 0; r < comm_.World(); ++r) {
ASSERT_EQ(data[r], r);
@@ -52,7 +52,7 @@ class Worker : public WorkerForTest {
std::iota(seg.begin(), seg.end(), comm_.Rank());
auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()});
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
for (std::int32_t r = 0; r < comm_.World(); ++r) {
auto seg = s_data.subspan(r * n, n);
@@ -81,7 +81,7 @@ class Worker : public WorkerForTest {
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
std::vector<std::int32_t> result;
auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2);
CheckV(result);
}
@@ -91,7 +91,7 @@ class Worker : public WorkerForTest {
std::int32_t n{comm_.Rank()};
std::vector<std::int32_t> result;
auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
for (std::int32_t i = 0; i < comm_.World(); ++i) {
ASSERT_EQ(result[i], i);
}
@@ -105,7 +105,7 @@ class Worker : public WorkerForTest {
std::vector<std::int64_t> sizes(comm_.World(), 0);
sizes[comm_.Rank()] = s_data.size_bytes();
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
std::shared_ptr<Coll> pcoll{new Coll{}};
std::vector<std::int64_t> recv_segments(comm_.World() + 1, 0);

View File

@@ -34,7 +34,7 @@ class Worker : public NCCLWorkerForTest {
std::vector<std::int64_t> sizes(comm_.World(), -1);
sizes[comm_.Rank()] = s_data.size_bytes();
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
// create result
dh::device_vector<std::int32_t> result(comm_.World(), -1);
auto s_result = common::EraseType(dh::ToSpan(result));
@@ -42,7 +42,7 @@ class Worker : public NCCLWorkerForTest {
std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
for (std::int32_t i = 0; i < comm_.World(); ++i) {
ASSERT_EQ(result[i], i);
@@ -58,7 +58,7 @@ class Worker : public NCCLWorkerForTest {
std::vector<std::int64_t> sizes(nccl_comm_->World(), 0);
sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes();
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
// create result
dh::device_vector<std::int32_t> result(n_bytes / sizeof(std::int32_t), -1);
@@ -67,7 +67,7 @@ class Worker : public NCCLWorkerForTest {
std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
// check segment size
if (algo != AllgatherVAlgo::kBcast) {
auto size = recv_seg[nccl_comm_->Rank() + 1] - recv_seg[nccl_comm_->Rank()];

View File

@@ -59,7 +59,7 @@ class AllreduceWorker : public WorkerForTest {
auto pcoll = std::shared_ptr<Coll>{new Coll{}};
auto rc = pcoll->Allreduce(comm_, common::EraseType(common::Span{data.data(), data.size()}),
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
for (auto v : data) {
ASSERT_EQ(v, ~std::uint32_t{0});
}

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#if defined(XGBOOST_USE_NCCL)
#include <gtest/gtest.h>
@@ -24,7 +24,7 @@ class Worker : public NCCLWorkerForTest {
data[comm_.Rank()] = ~std::uint32_t{0};
auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
thrust::host_vector<std::uint32_t> h_data(data.size());
thrust::copy(data.cbegin(), data.cend(), h_data.begin());
for (auto v : h_data) {
@@ -36,7 +36,7 @@ class Worker : public NCCLWorkerForTest {
dh::device_vector<double> data(314, 1.5);
auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
ArrayInterfaceHandler::kF8, Op::kSum);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
for (std::size_t i = 0; i < data.size(); ++i) {
auto v = data[i];
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/collective/socket.h>
@@ -10,7 +10,6 @@
#include <vector> // for vector
#include "../../../src/collective/broadcast.h" // for Broadcast
#include "../../../src/collective/tracker.h" // for GetHostAddress
#include "test_worker.h" // for WorkerForTest, TestDistributed
namespace xgboost::collective {
@@ -24,14 +23,14 @@ class Worker : public WorkerForTest {
// basic test
std::vector<std::int32_t> data(1, comm_.Rank());
auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
ASSERT_EQ(data[0], r);
}
for (std::int32_t r = 0; r < comm_.World(); ++r) {
std::vector<std::int32_t> data(1 << 16, comm_.Rank());
auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
ASSERT_EQ(data[0], r);
}
}
@@ -41,11 +40,11 @@ class BroadcastTest : public SocketTest {};
} // namespace
TEST_F(BroadcastTest, Basic) {
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
std::int32_t n_workers = std::min(2u, std::thread::hardware_concurrency());
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
Worker worker{host, port, timeout, n_workers, r};
worker.Run();
});
} // namespace
}
} // namespace xgboost::collective

View File

@@ -1,14 +1,16 @@
/*!
* Copyright 2017-2021 XGBoost contributors
/**
* Copyright 2017-2024, XGBoost contributors
*/
#include <thrust/device_vector.h>
#include <thrust/sort.h> // for is_sorted
#include <xgboost/base.h>
#include <cstddef>
#include <cstdint>
#include <thrust/device_vector.h>
#include <vector>
#include <xgboost/base.h>
#include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/quantile.h"
#include "../helpers.h"
#include "gtest/gtest.h"
TEST(SumReduce, Test) {

View File

@@ -1,10 +1,11 @@
/**
* Copyright 2019-2023, XGBoost Contributors
* Copyright 2019-2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <cstddef> // for size_t
#include <fstream> // for ofstream
#include <numeric> // for iota
#include "../../../src/common/io.h"
#include "../filesystem.h" // dmlc::TemporaryDirectory

View File

@@ -4,10 +4,10 @@
#include <gtest/gtest.h>
#include <fstream>
#include <iterator> // for back_inserter
#include <limits> // for numeric_limits
#include <map>
#include <numeric> // for iota
#include "../../../src/common/charconv.h"
#include "../../../src/common/io.h"
#include "../../../src/common/json_utils.h"
#include "../../../src/common/threading_utils.h" // for ParallelFor