From 4c0e4422d0e2115928fd68e29f7c1c67cc969854 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 18 Oct 2023 10:22:18 +0800 Subject: [PATCH] [coll] allgather. (#9681) --- src/collective/allgather.h | 20 ++++++++ tests/cpp/collective/test_allgather.cc | 71 ++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) create mode 100644 tests/cpp/collective/test_allgather.cc diff --git a/src/collective/allgather.h b/src/collective/allgather.h index 31a9a36b3..5dcb4ebdd 100644 --- a/src/collective/allgather.h +++ b/src/collective/allgather.h @@ -20,4 +20,24 @@ namespace cpu_impl { std::shared_ptr prev_ch, std::shared_ptr next_ch); } // namespace cpu_impl + +template +[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span data, std::size_t size) { + auto n_total_bytes = data.size_bytes(); + auto n_bytes = sizeof(T) * size; + auto erased = + common::Span{reinterpret_cast(data.data()), n_total_bytes}; + + auto rank = comm.Rank(); + auto prev = BootstrapPrev(rank, comm.World()); + auto next = BootstrapNext(rank, comm.World()); + + auto prev_ch = comm.Chan(prev); + auto next_ch = comm.Chan(next); + auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes, 0, prev_ch, next_ch); + if (!rc.OK()) { + return rc; + } + return comm.Block(); +} } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_allgather.cc b/tests/cpp/collective/test_allgather.cc new file mode 100644 index 000000000..49ba591d0 --- /dev/null +++ b/tests/cpp/collective/test_allgather.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include +#include // for Span + +#include // for int32_t +#include // for iota +#include // for string +#include // for thread +#include // for vector + +#include "../../../src/collective/allgather.h" +#include "../../../src/collective/tracker.h" // for GetHostAddress, Tracker +#include "test_worker.h" // for TestDistributed== + +namespace xgboost::collective { +namespace { +class AllgatherTest : public TrackerTest {}; + +class Worker : public WorkerForTest { + public: + using WorkerForTest::WorkerForTest; + + void Run() { + { + // basic test + std::vector data(comm_.World(), 0); + data[comm_.Rank()] = comm_.Rank(); + + auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()}, 1); + ASSERT_TRUE(rc.OK()) << rc.Report(); + + for (std::int32_t r = 0; r < comm_.World(); ++r) { + ASSERT_EQ(data[r], r); + } + } + { + // test for limited socket buffer + this->LimitSockBuf(4096); + + std::size_t n = 8192; // n_bytes = 8192 * sizeof(int) + std::vector data(comm_.World() * n, 0); + auto s_data = common::Span{data.data(), data.size()}; + auto seg = s_data.subspan(comm_.Rank() * n, n); + std::iota(seg.begin(), seg.end(), comm_.Rank()); + + auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()}, n); + ASSERT_TRUE(rc.OK()) << rc.Report(); + + for (std::int32_t r = 0; r < comm_.World(); ++r) { + auto seg = s_data.subspan(r * n, n); + for (std::int32_t i = 0; i < static_cast(seg.size()); ++i) { + auto v = seg[i]; + ASSERT_EQ(v, r + i); + } + } + } + } +}; +} // namespace + +TEST_F(AllgatherTest, Basic) { + std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + Worker worker{host, port, timeout, n_workers, r}; + worker.Run(); + }); +} +} // namespace xgboost::collective