[coll] Allreduce. (#9679)
This commit is contained in:
72
tests/cpp/collective/test_allreduce.cc
Normal file
72
tests/cpp/collective/test_allreduce.cc
Normal file
@@ -0,0 +1,72 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>

#include <algorithm>  // for clamp
#include <chrono>     // for seconds
#include <cstddef>    // for size_t
#include <cstdint>    // for int32_t
#include <numeric>    // for accumulate
#include <string>     // for string
#include <thread>     // for thread
#include <vector>     // for vector

#include "../../../src/collective/allreduce.h"
#include "../../../src/collective/tracker.h"
#include "test_worker.h"  // for WorkerForTest, TestDistributed
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
namespace {
|
||||
class AllreduceWorker : public WorkerForTest {
|
||||
public:
|
||||
using WorkerForTest::WorkerForTest;
|
||||
|
||||
void Basic() {
|
||||
{
|
||||
std::vector<double> data(13, 0.0);
|
||||
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||
rhs[i] += lhs[i];
|
||||
}
|
||||
});
|
||||
ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0);
|
||||
}
|
||||
{
|
||||
std::vector<double> data(1, 1.0);
|
||||
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||
rhs[i] += lhs[i];
|
||||
}
|
||||
});
|
||||
ASSERT_EQ(data[0], static_cast<double>(comm_.World()));
|
||||
}
|
||||
}
|
||||
|
||||
void Acc() {
|
||||
std::vector<double> data(314, 1.5);
|
||||
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||
rhs[i] += lhs[i];
|
||||
}
|
||||
});
|
||||
for (std::size_t i = 0; i < data.size(); ++i) {
|
||||
auto v = data[i];
|
||||
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Fixture for the Allreduce tests; it adds nothing on top of SocketTest.
class AllreduceTest : public SocketTest {};
|
||||
} // namespace
|
||||
|
||||
// Runs AllreduceWorker::Basic once per worker through TestDistributed.
TEST_F(AllreduceTest, Basic) {
  // std::thread::hardware_concurrency() is allowed to return 0 when the value cannot be
  // determined. Clamp to [1, 7] so that at least one worker always runs and the test is never
  // silently vacuous.
  auto const n_workers =
      static_cast<std::int32_t>(std::clamp(std::thread::hardware_concurrency(), 1u, 7u));
  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t r) {
    AllreduceWorker worker{host, port, timeout, n_workers, r};
    worker.Basic();
  });
}
|
||||
|
||||
// Runs AllreduceWorker::Acc once per worker through TestDistributed.
TEST_F(AllreduceTest, Sum) {
  // std::thread::hardware_concurrency() is allowed to return 0 when the value cannot be
  // determined. Clamp to [1, 7] so that at least one worker always runs and the test is never
  // silently vacuous.
  auto const n_workers =
      static_cast<std::int32_t>(std::clamp(std::thread::hardware_concurrency(), 1u, 7u));
  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t r) {
    AllreduceWorker worker{host, port, timeout, n_workers, r};
    worker.Acc();
  });
}
|
||||
} // namespace xgboost::collective
|
||||
@@ -10,8 +10,8 @@
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/collective/broadcast.h" // for Broadcast
|
||||
#include "../../../src/collective/tracker.h" // for GetHostAddress, Tracker
|
||||
#include "test_worker.h" // for WorkerForTest
|
||||
#include "../../../src/collective/tracker.h" // for GetHostAddress
|
||||
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
@@ -41,28 +41,11 @@ class BroadcastTest : public SocketTest {};
|
||||
} // namespace
|
||||
|
||||
TEST_F(BroadcastTest, Basic) {
|
||||
std::int32_t n_workers = std::min(24u, std::thread::hardware_concurrency());
|
||||
std::chrono::seconds timeout{3};
|
||||
|
||||
std::string host;
|
||||
ASSERT_TRUE(GetHostAddress(&host).OK());
|
||||
RabitTracker tracker{StringView{host}, n_workers, 0, timeout};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] {
|
||||
Worker worker{host, port, timeout, n_workers, i};
|
||||
worker.Run();
|
||||
});
|
||||
}
|
||||
|
||||
for (auto& t : workers) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
Worker worker{host, port, timeout, n_workers, r};
|
||||
worker.Run();
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -88,4 +88,27 @@ class TrackerTest : public SocketTest {
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WorkerFn>
|
||||
void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) {
|
||||
std::chrono::seconds timeout{1};
|
||||
|
||||
std::string host;
|
||||
ASSERT_TRUE(GetHostAddress(&host).OK());
|
||||
RabitTracker tracker{StringView{host}, n_workers, 0, timeout};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] { worker_fn(host, port, timeout, i); });
|
||||
}
|
||||
|
||||
for (auto& t : workers) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
Reference in New Issue
Block a user