[coll] Allreduce. (#9679)

This commit is contained in:
Jiaming Yuan
2023-10-17 13:57:14 +08:00
committed by GitHub
parent da6803b75b
commit 48ac9b6cbe
8 changed files with 301 additions and 81 deletions

View File

@@ -10,8 +10,8 @@
#include <vector> // for vector
#include "../../../src/collective/broadcast.h" // for Broadcast
#include "../../../src/collective/tracker.h" // for GetHostAddress, Tracker
#include "test_worker.h" // for WorkerForTest
#include "../../../src/collective/tracker.h" // for GetHostAddress
#include "test_worker.h" // for WorkerForTest, TestDistributed
namespace xgboost::collective {
namespace {
@@ -41,28 +41,11 @@ class BroadcastTest : public SocketTest {};
} // namespace
TEST_F(BroadcastTest, Basic) {
std::int32_t n_workers = std::min(24u, std::thread::hardware_concurrency());
std::chrono::seconds timeout{3};
std::string host;
ASSERT_TRUE(GetHostAddress(&host).OK());
RabitTracker tracker{StringView{host}, n_workers, 0, timeout};
auto fut = tracker.Run();
std::vector<std::thread> workers;
std::int32_t port = tracker.Port();
for (std::int32_t i = 0; i < n_workers; ++i) {
workers.emplace_back([=] {
Worker worker{host, port, timeout, n_workers, i};
worker.Run();
});
}
for (auto& t : workers) {
t.join();
}
ASSERT_TRUE(fut.get().OK());
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
Worker worker{host, port, timeout, n_workers, r};
worker.Run();
});
}
} // namespace xgboost::collective