[coll] Implement a new tracker and a communicator. (#9650)
* [coll] Implement a new tracker and a communicator. The new tracker and communicators communicate through the use of JSON documents. Along with which, communicators are aware of each other.
This commit is contained in:
47
tests/cpp/collective/test_comm.cc
Normal file
47
tests/cpp/collective/test_comm.cc
Normal file
@@ -0,0 +1,47 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "../../../src/collective/comm.h"
|
||||
#include "test_worker.h"
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class CommTest : public TrackerTest {};
|
||||
} // namespace
|
||||
|
||||
TEST_F(CommTest, Channel) {
|
||||
auto n_workers = 4;
|
||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] {
|
||||
WorkerForTest worker{host, port, timeout, n_workers, i};
|
||||
if (i % 2 == 0) {
|
||||
auto p_chan = worker.Comm().Chan(i + 1);
|
||||
p_chan->SendAll(
|
||||
EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
|
||||
auto rc = p_chan->Block();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
} else {
|
||||
auto p_chan = worker.Comm().Chan(i - 1);
|
||||
std::int32_t r{-1};
|
||||
p_chan->RecvAll(EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
|
||||
auto rc = p_chan->Block();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_EQ(r, i - 1);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for (auto &w : workers) {
|
||||
w.join();
|
||||
}
|
||||
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <cerrno> // EADDRNOTAVAIL
|
||||
#include <system_error> // std::error_code, std::system_category
|
||||
|
||||
#include "net_test.h" // for SocketTest
|
||||
#include "test_worker.h" // for SocketTest
|
||||
|
||||
namespace xgboost::collective {
|
||||
TEST_F(SocketTest, Basic) {
|
||||
|
||||
@@ -1,18 +1,67 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include "../../../src/collective/tracker.h" // for GetHostAddress
|
||||
#include "net_test.h" // for SocketTest
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int32_t
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/collective/comm.h"
|
||||
#include "test_worker.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class TrackerTest : public SocketTest {};
|
||||
class PrintWorker : public WorkerForTest {
|
||||
public:
|
||||
using WorkerForTest::WorkerForTest;
|
||||
|
||||
void Print() {
|
||||
auto rc = comm_.LogTracker("ack:" + std::to_string(this->comm_.Rank()));
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST_F(TrackerTest, GetHostAddress) {
|
||||
std::string host;
|
||||
auto rc = GetHostAddress(&host);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
ASSERT_TRUE(host.find("127.") == std::string::npos);
|
||||
TEST_F(TrackerTest, Bootstrap) {
|
||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] { WorkerForTest worker{host, port, timeout, n_workers, i}; });
|
||||
}
|
||||
for (auto &w : workers) {
|
||||
w.join();
|
||||
}
|
||||
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
|
||||
TEST_F(TrackerTest, Print) {
|
||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
||||
auto fut = tracker.Run();
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] {
|
||||
PrintWorker worker{host, port, timeout, n_workers, i};
|
||||
worker.Print();
|
||||
});
|
||||
}
|
||||
|
||||
for (auto &w : workers) {
|
||||
w.join();
|
||||
}
|
||||
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
}
|
||||
|
||||
TEST_F(TrackerTest, GetHostAddress) { ASSERT_TRUE(host.find("127.") == std::string::npos); }
|
||||
} // namespace xgboost::collective
|
||||
|
||||
91
tests/cpp/collective/test_worker.h
Normal file
91
tests/cpp/collective/test_worker.h
Normal file
@@ -0,0 +1,91 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int32_t
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/collective/comm.h"
|
||||
#include "../../../src/collective/tracker.h" // for GetHostAddress
|
||||
#include "../helpers.h" // for FileExists
|
||||
|
||||
namespace xgboost::collective {
|
||||
class WorkerForTest {
|
||||
std::string tracker_host_;
|
||||
std::int32_t tracker_port_;
|
||||
std::int32_t world_size_;
|
||||
|
||||
protected:
|
||||
std::int32_t retry_{1};
|
||||
std::string task_id_;
|
||||
RabitComm comm_;
|
||||
|
||||
public:
|
||||
WorkerForTest(std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t world, std::int32_t rank)
|
||||
: tracker_host_{std::move(host)},
|
||||
tracker_port_{port},
|
||||
world_size_{world},
|
||||
task_id_{"t:" + std::to_string(rank)},
|
||||
comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_} {
|
||||
CHECK_EQ(world_size_, comm_.World());
|
||||
}
|
||||
virtual ~WorkerForTest() = default;
|
||||
auto& Comm() { return comm_; }
|
||||
|
||||
void LimitSockBuf(std::int32_t n_bytes) {
|
||||
for (std::int32_t i = 0; i < comm_.World(); ++i) {
|
||||
if (i != comm_.Rank()) {
|
||||
ASSERT_TRUE(comm_.Chan(i)->Socket()->NonBlocking());
|
||||
ASSERT_TRUE(comm_.Chan(i)->Socket()->SetBufSize(n_bytes).OK());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class SocketTest : public ::testing::Test {
|
||||
protected:
|
||||
std::string skip_msg_{"Skipping IPv6 test"};
|
||||
|
||||
bool SkipTest() {
|
||||
std::string path{"/sys/module/ipv6/parameters/disable"};
|
||||
if (FileExists(path)) {
|
||||
std::ifstream fin(path);
|
||||
if (!fin) {
|
||||
return true;
|
||||
}
|
||||
std::string s_value;
|
||||
fin >> s_value;
|
||||
auto value = std::stoi(s_value);
|
||||
if (value != 0) {
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetUp() override { system::SocketStartup(); }
|
||||
void TearDown() override { system::SocketFinalize(); }
|
||||
};
|
||||
|
||||
class TrackerTest : public SocketTest {
|
||||
public:
|
||||
std::int32_t n_workers{2};
|
||||
std::chrono::seconds timeout{1};
|
||||
std::string host;
|
||||
|
||||
void SetUp() override {
|
||||
SocketTest::SetUp();
|
||||
auto rc = GetHostAddress(&host);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
Reference in New Issue
Block a user