[coll] Add nccl. (#9726)

This commit is contained in:
Jiaming Yuan
2023-10-28 16:33:58 +08:00
committed by GitHub
parent 0c621094b3
commit 6755179e77
19 changed files with 924 additions and 111 deletions

View File

@@ -14,6 +14,7 @@
#include <vector> // for vector
#include "../../../src/collective/allgather.h" // for RingAllgather
#include "../../../src/collective/coll.h" // for Coll
#include "../../../src/collective/comm.h" // for RabitComm
#include "gtest/gtest.h" // for AssertionR...
#include "test_worker.h" // for TestDistri...
@@ -63,37 +64,79 @@ class Worker : public WorkerForTest {
}
}
void TestV() {
{
// basic test
std::int32_t n{comm_.Rank()};
std::vector<std::int32_t> result;
auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result);
ASSERT_TRUE(rc.OK()) << rc.Report();
for (std::int32_t i = 0; i < comm_.World(); ++i) {
ASSERT_EQ(result[i], i);
}
}
{
// V test
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
std::vector<std::int32_t> result;
auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result);
ASSERT_TRUE(rc.OK()) << rc.Report();
ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2);
std::int32_t k{0};
for (std::int32_t r = 0; r < comm_.World(); ++r) {
auto seg = common::Span{result.data(), result.size()}.subspan(k, (r + 1));
if (comm_.Rank() == 0) {
for (auto v : seg) {
ASSERT_EQ(v, r);
}
k += seg.size();
// Validate the layout of a gathered variable-length buffer: the segment that
// belongs to worker `owner` is expected to hold `owner + 1` copies of the
// value `owner`. Values are inspected, and the read offset advanced, only on
// rank 0; on every other rank each segment is still sliced (from offset 0)
// but its contents are not examined.
void CheckV(common::Span<std::int32_t> result) {
  std::int32_t offset{0};
  for (std::int32_t owner = 0; owner < comm_.World(); ++owner) {
    // Slice out the `owner + 1` elements attributed to `owner`.
    auto segment = result.subspan(offset, owner + 1);
    if (comm_.Rank() != 0) {
      continue;
    }
    for (auto value : segment) {
      ASSERT_EQ(value, owner);
    }
    offset += segment.size();
  }
}
// Variable-length ring allgather: this worker contributes `Rank() + 1`
// elements, all equal to its rank. The gathered buffer must contain
// 1 + 2 + ... + World() elements and pass the CheckV layout validation.
void TestVRing() {
  std::vector<std::int32_t> local(comm_.Rank() + 1, comm_.Rank());
  std::vector<std::int32_t> result;
  auto rc = RingAllgatherV(comm_, common::Span{local.data(), local.size()}, &result);
  ASSERT_TRUE(rc.OK()) << rc.Report();

  // Closed form of the arithmetic series 1 + 2 + ... + World().
  auto const n_expected = (1 + comm_.World()) * comm_.World() / 2;
  ASSERT_EQ(result.size(), n_expected);
  CheckV(result);
}
// Basic allgather-V check: every worker contributes exactly one element, its
// own rank, so slot `i` of the gathered result must equal `i`.
void TestVBasic() {
  std::int32_t n{comm_.Rank()};
  std::vector<std::int32_t> result;
  auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result);
  ASSERT_TRUE(rc.OK()) << rc.Report();
  // Guard the indexing below: without this, a result shorter than World()
  // would be read out of bounds instead of producing a test failure.
  ASSERT_EQ(result.size(), static_cast<std::size_t>(comm_.World()));
  for (std::int32_t i = 0; i < comm_.World(); ++i) {
    ASSERT_EQ(result[i], i);
  }
}
void TestVAlgo() {
// V test, broadcast
std::vector<std::int32_t> data(comm_.Rank() + 1, comm_.Rank());
auto s_data = common::Span{data.data(), data.size()};
std::vector<std::int64_t> sizes(comm_.World(), 0);
sizes[comm_.Rank()] = s_data.size_bytes();
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
ASSERT_TRUE(rc.OK()) << rc.Report();
std::shared_ptr<Coll> pcoll{new Coll{}};
std::vector<std::int64_t> recv_segments(comm_.World() + 1, 0);
std::vector<std::int32_t> recv(std::accumulate(sizes.cbegin(), sizes.cend(), 0));
auto s_recv = common::Span{recv.data(), recv.size()};
rc = pcoll->AllgatherV(comm_, common::EraseType(s_data),
common::Span{sizes.data(), sizes.size()},
common::Span{recv_segments.data(), recv_segments.size()},
common::EraseType(s_recv), AllgatherVAlgo::kBcast);
ASSERT_TRUE(rc.OK());
CheckV(s_recv);
// Test inplace
auto test_inplace = [&] (AllgatherVAlgo algo) {
std::fill_n(s_recv.data(), s_recv.size(), 0);
auto current = s_recv.subspan(recv_segments[comm_.Rank()],
recv_segments[comm_.Rank() + 1] - recv_segments[comm_.Rank()]);
std::copy_n(data.data(), data.size(), current.data());
rc = pcoll->AllgatherV(comm_, common::EraseType(current),
common::Span{sizes.data(), sizes.size()},
common::Span{recv_segments.data(), recv_segments.size()},
common::EraseType(s_recv), algo);
ASSERT_TRUE(rc.OK());
CheckV(s_recv);
};
test_inplace(AllgatherVAlgo::kBcast);
test_inplace(AllgatherVAlgo::kRing);
}
};
} // namespace
@@ -106,12 +149,30 @@ TEST_F(AllgatherTest, Basic) {
});
}
TEST_F(AllgatherTest, V) {
TEST_F(AllgatherTest, VBasic) {
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
Worker worker{host, port, timeout, n_workers, r};
worker.TestV();
worker.TestVBasic();
});
}
// Run Worker::TestVRing on up to 7 workers (capped by the hardware thread count).
TEST_F(AllgatherTest, VRing) {
  auto const n_workers =
      static_cast<std::int32_t>(std::min(7u, std::thread::hardware_concurrency()));
  auto run_worker = [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                        std::int32_t rank) {
    Worker worker{host, port, timeout, n_workers, rank};
    worker.TestVRing();
  };
  TestDistributed(n_workers, run_worker);
}
// Run Worker::TestVAlgo on up to 7 workers (capped by the hardware thread count).
TEST_F(AllgatherTest, VAlgo) {
  auto const n_workers =
      static_cast<std::int32_t>(std::min(7u, std::thread::hardware_concurrency()));
  auto run_worker = [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                        std::int32_t rank) {
    Worker worker{host, port, timeout, n_workers, rank};
    worker.TestVAlgo();
  };
  TestDistributed(n_workers, run_worker);
}
} // namespace xgboost::collective