[coll] allgather. (#9681)
This commit is contained in:
parent
48ac9b6cbe
commit
4c0e4422d0
@ -20,4 +20,24 @@ namespace cpu_impl {
|
|||||||
std::shared_ptr<Channel> prev_ch,
|
std::shared_ptr<Channel> prev_ch,
|
||||||
std::shared_ptr<Channel> next_ch);
|
std::shared_ptr<Channel> next_ch);
|
||||||
} // namespace cpu_impl
|
} // namespace cpu_impl
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
|
||||||
|
auto n_total_bytes = data.size_bytes();
|
||||||
|
auto n_bytes = sizeof(T) * size;
|
||||||
|
auto erased =
|
||||||
|
common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes};
|
||||||
|
|
||||||
|
auto rank = comm.Rank();
|
||||||
|
auto prev = BootstrapPrev(rank, comm.World());
|
||||||
|
auto next = BootstrapNext(rank, comm.World());
|
||||||
|
|
||||||
|
auto prev_ch = comm.Chan(prev);
|
||||||
|
auto next_ch = comm.Chan(next);
|
||||||
|
auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes, 0, prev_ch, next_ch);
|
||||||
|
if (!rc.OK()) {
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
return comm.Block();
|
||||||
|
}
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
71
tests/cpp/collective/test_allgather.cc
Normal file
71
tests/cpp/collective/test_allgather.cc
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/span.h> // for Span
|
||||||
|
|
||||||
|
#include <cstdint> // for int32_t
|
||||||
|
#include <numeric> // for iota
|
||||||
|
#include <string> // for string
|
||||||
|
#include <thread> // for thread
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
|
#include "../../../src/collective/allgather.h"
|
||||||
|
#include "../../../src/collective/tracker.h" // for GetHostAddress, Tracker
|
||||||
|
#include "test_worker.h" // for TestDistributed
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
namespace {
|
||||||
|
// Fixture used by the allgather TEST_F cases; adds nothing on top of TrackerTest.
class AllgatherTest : public TrackerTest {};
|
||||||
|
|
||||||
|
class Worker : public WorkerForTest {
|
||||||
|
public:
|
||||||
|
using WorkerForTest::WorkerForTest;
|
||||||
|
|
||||||
|
void Run() {
|
||||||
|
{
|
||||||
|
// basic test
|
||||||
|
std::vector<std::int32_t> data(comm_.World(), 0);
|
||||||
|
data[comm_.Rank()] = comm_.Rank();
|
||||||
|
|
||||||
|
auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()}, 1);
|
||||||
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
|
|
||||||
|
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||||
|
ASSERT_EQ(data[r], r);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// test for limited socket buffer
|
||||||
|
this->LimitSockBuf(4096);
|
||||||
|
|
||||||
|
std::size_t n = 8192; // n_bytes = 8192 * sizeof(int)
|
||||||
|
std::vector<std::int32_t> data(comm_.World() * n, 0);
|
||||||
|
auto s_data = common::Span{data.data(), data.size()};
|
||||||
|
auto seg = s_data.subspan(comm_.Rank() * n, n);
|
||||||
|
std::iota(seg.begin(), seg.end(), comm_.Rank());
|
||||||
|
|
||||||
|
auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()}, n);
|
||||||
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
|
|
||||||
|
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||||
|
auto seg = s_data.subspan(r * n, n);
|
||||||
|
for (std::int32_t i = 0; i < static_cast<std::int32_t>(seg.size()); ++i) {
|
||||||
|
auto v = seg[i];
|
||||||
|
ASSERT_EQ(v, r + i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Launches up to seven workers through TestDistributed; each one constructs a
// Worker and calls Run() on it.
TEST_F(AllgatherTest, Basic) {
  // NOTE(review): std::thread::hardware_concurrency() is allowed to return 0
  // when the value is not computable, which would make n_workers 0 here --
  // confirm that TestDistributed copes with that.
  unsigned const hw_threads = std::thread::hardware_concurrency();
  std::int32_t n_workers = std::min(7u, hw_threads);

  auto run_worker = [n_workers](std::string host, std::int32_t port,
                                std::chrono::seconds timeout, std::int32_t r) {
    Worker worker{host, port, timeout, n_workers, r};
    worker.Run();
  };
  TestDistributed(n_workers, run_worker);
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
Loading…
x
Reference in New Issue
Block a user