[coll] Implement shutdown for tracker and comm. (#10208)

- Force shutdown the tracker.
- Implement shutdown notice for error handling thread in comm.
This commit is contained in:
Jiaming Yuan
2024-04-20 04:08:17 +08:00
committed by GitHub
parent 8fb05c8c95
commit 3fbb221fec
24 changed files with 553 additions and 199 deletions

View File

@@ -25,13 +25,13 @@ TEST_F(TrackerAPITest, CAPI) {
auto config_str = Json::Dump(config);
auto rc = XGTrackerCreate(config_str.c_str(), &handle);
ASSERT_EQ(rc, 0);
rc = XGTrackerRun(handle);
rc = XGTrackerRun(handle, nullptr);
ASSERT_EQ(rc, 0);
std::thread bg_wait{[&] {
Json config{Object{}};
auto config_str = Json::Dump(config);
auto rc = XGTrackerWait(handle, config_str.c_str());
auto rc = XGTrackerWaitFor(handle, config_str.c_str());
ASSERT_EQ(rc, 0);
}};
@@ -42,8 +42,8 @@ TEST_F(TrackerAPITest, CAPI) {
std::string host;
ASSERT_TRUE(GetHostAddress(&host).OK());
ASSERT_EQ(host, get<String const>(args["DMLC_TRACKER_URI"]));
auto port = get<Integer const>(args["DMLC_TRACKER_PORT"]);
ASSERT_EQ(host, get<String const>(args["dmlc_tracker_uri"]));
auto port = get<Integer const>(args["dmlc_tracker_port"]);
ASSERT_NE(port, 0);
std::vector<std::thread> workers;

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
@@ -14,7 +14,7 @@ class CommTest : public TrackerTest {};
TEST_F(CommTest, Channel) {
auto n_workers = 4;
RabitTracker tracker{host, n_workers, 0, timeout};
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
auto fut = tracker.Run();
std::vector<std::thread> workers;
@@ -29,7 +29,7 @@ TEST_F(CommTest, Channel) {
return p_chan->SendAll(
EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
} << [&] { return p_chan->Block(); };
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
} else {
auto p_chan = worker.Comm().Chan(i - 1);
std::int32_t r{-1};
@@ -37,7 +37,7 @@ TEST_F(CommTest, Channel) {
return p_chan->RecvAll(
EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
} << [&] { return p_chan->Block(); };
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
ASSERT_EQ(r, i - 1);
}
});

View File

@@ -17,17 +17,6 @@
namespace xgboost::collective {
namespace {
auto MakeConfig(std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) {
Json config{Object{}};
config["dmlc_communicator"] = std::string{"rabit"};
config["DMLC_TRACKER_URI"] = host;
config["DMLC_TRACKER_PORT"] = port;
config["dmlc_timeout_sec"] = static_cast<std::int64_t>(timeout.count());
config["DMLC_TASK_ID"] = std::to_string(r);
config["dmlc_retry"] = 2;
return config;
}
class CommGroupTest : public SocketTest {};
} // namespace
@@ -36,7 +25,7 @@ TEST_F(CommGroupTest, Basic) {
TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
Context ctx;
auto config = MakeConfig(host, port, timeout, r);
auto config = MakeDistributedTestConfig(host, port, timeout, r);
std::unique_ptr<CommGroup> ptr{CommGroup::Create(config)};
ASSERT_TRUE(ptr->IsDistributed());
ASSERT_EQ(ptr->World(), n_workers);
@@ -52,7 +41,7 @@ TEST_F(CommGroupTest, BasicGPU) {
TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
auto ctx = MakeCUDACtx(r);
auto config = MakeConfig(host, port, timeout, r);
auto config = MakeDistributedTestConfig(host, port, timeout, r);
std::unique_ptr<CommGroup> ptr{CommGroup::Create(config)};
auto const& comm = ptr->Ctx(&ctx, DeviceOrd::CUDA(0));
ASSERT_EQ(comm.TaskID(), std::to_string(r));

View File

@@ -28,13 +28,11 @@ class LoopTest : public ::testing::Test {
auto domain = SockDomain::kV4;
pair_.first = TCPSocket::Create(domain);
in_port_t port{0};
std::int32_t port{0};
auto rc = Success() << [&] {
port = pair_.first.BindHost();
return Success();
return pair_.first.BindHost(&port);
} << [&] {
pair_.first.Listen();
return Success();
return pair_.first.Listen();
};
SafeColl(rc);

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2022-2023, XGBoost Contributors
* Copyright 2022-2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/collective/socket.h>
@@ -21,14 +21,19 @@ TEST_F(SocketTest, Basic) {
auto run_test = [msg](SockDomain domain) {
auto server = TCPSocket::Create(domain);
ASSERT_EQ(server.Domain(), domain);
auto port = server.BindHost();
server.Listen();
std::int32_t port{0};
auto rc = Success() << [&] {
return server.BindHost(&port);
} << [&] {
return server.Listen();
};
SafeColl(rc);
TCPSocket client;
if (domain == SockDomain::kV4) {
auto const& addr = SockAddrV4::Loopback().Addr();
auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
} else {
auto const& addr = SockAddrV6::Loopback().Addr();
auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client);
@@ -45,7 +50,8 @@ TEST_F(SocketTest, Basic) {
accepted.Send(msg);
std::string str;
client.Recv(&str);
rc = client.Recv(&str);
SafeColl(rc);
ASSERT_EQ(StringView{str}, msg);
};

View File

@@ -1,6 +1,7 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <chrono> // for seconds
@@ -10,6 +11,7 @@
#include <vector> // for vector
#include "../../../src/collective/comm.h"
#include "../helpers.h" // for GMockThrow
#include "test_worker.h"
namespace xgboost::collective {
@@ -20,13 +22,13 @@ class PrintWorker : public WorkerForTest {
void Print() {
auto rc = comm_.LogTracker("ack:" + std::to_string(this->comm_.Rank()));
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
}
};
} // namespace
TEST_F(TrackerTest, Bootstrap) {
RabitTracker tracker{host, n_workers, 0, timeout};
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
ASSERT_FALSE(tracker.Ready());
auto fut = tracker.Run();
@@ -34,7 +36,7 @@ TEST_F(TrackerTest, Bootstrap) {
auto args = tracker.WorkerArgs();
ASSERT_TRUE(tracker.Ready());
ASSERT_EQ(get<String const>(args["DMLC_TRACKER_URI"]), host);
ASSERT_EQ(get<String const>(args["dmlc_tracker_uri"]), host);
std::int32_t port = tracker.Port();
@@ -44,12 +46,11 @@ TEST_F(TrackerTest, Bootstrap) {
for (auto &w : workers) {
w.join();
}
ASSERT_TRUE(fut.get().OK());
SafeColl(fut.get());
}
TEST_F(TrackerTest, Print) {
RabitTracker tracker{host, n_workers, 0, timeout};
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
auto fut = tracker.Run();
std::vector<std::thread> workers;
@@ -73,4 +74,47 @@ TEST_F(TrackerTest, Print) {
}
TEST_F(TrackerTest, GetHostAddress) { ASSERT_TRUE(host.find("127.") == std::string::npos); }
/**
* Test connecting the tracker after it has finished. This should not hang the workers.
*/
TEST_F(TrackerTest, AfterShutdown) {
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
auto fut = tracker.Run();
std::vector<std::thread> workers;
auto rc = tracker.WaitUntilReady();
ASSERT_TRUE(rc.OK());
std::int32_t port = tracker.Port();
// Launch no-op workers to cause the tracker to shutdown.
for (std::int32_t i = 0; i < n_workers; ++i) {
workers.emplace_back([=] { WorkerForTest worker{host, port, timeout, n_workers, i}; });
}
for (auto &w : workers) {
w.join();
}
ASSERT_TRUE(fut.get().OK());
// Launch workers again, they should fail.
workers.clear();
for (std::int32_t i = 0; i < n_workers; ++i) {
auto assert_that = [=] {
WorkerForTest worker{host, port, timeout, n_workers, i};
};
// On a Linux platform, the connection will be refused, on Apple platform, this gets
// an operation now in progress poll failure, on Windows, it's a timeout error.
#if defined(__linux__)
workers.emplace_back([=] { ASSERT_THAT(assert_that, GMockThrow("Connection refused")); });
#else
workers.emplace_back([=] { ASSERT_THAT(assert_that, GMockThrow("Failed to connect to")); });
#endif
}
for (auto &w : workers) {
w.join();
}
}
} // namespace xgboost::collective

View File

@@ -37,7 +37,7 @@ class WorkerForTest {
comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_, DefaultNcclName()} {
CHECK_EQ(world_size_, comm_.World());
}
virtual ~WorkerForTest() = default;
virtual ~WorkerForTest() noexcept(false) { SafeColl(comm_.Shutdown()); }
auto& Comm() { return comm_; }
void LimitSockBuf(std::int32_t n_bytes) {
@@ -87,19 +87,30 @@ class TrackerTest : public SocketTest {
void SetUp() override {
SocketTest::SetUp();
auto rc = GetHostAddress(&host);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
}
};
inline Json MakeTrackerConfig(std::string host, std::int32_t n_workers,
std::chrono::seconds timeout) {
Json config{Object{}};
config["host"] = host;
config["port"] = Integer{0};
config["n_workers"] = Integer{n_workers};
config["sortby"] = Integer{static_cast<std::int32_t>(Tracker::SortBy::kHost)};
config["timeout"] = timeout.count();
return config;
}
template <typename WorkerFn>
void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) {
std::chrono::seconds timeout{2};
std::string host;
auto rc = GetHostAddress(&host);
ASSERT_TRUE(rc.OK()) << rc.Report();
SafeColl(rc);
LOG(INFO) << "Using " << n_workers << " workers for test.";
RabitTracker tracker{StringView{host}, n_workers, 0, timeout};
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
auto fut = tracker.Run();
std::vector<std::thread> workers;
@@ -115,4 +126,15 @@ void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) {
ASSERT_TRUE(fut.get().OK());
}
inline auto MakeDistributedTestConfig(std::string host, std::int32_t port,
std::chrono::seconds timeout, std::int32_t r) {
Json config{Object{}};
config["dmlc_communicator"] = std::string{"rabit"};
config["dmlc_tracker_uri"] = host;
config["dmlc_tracker_port"] = port;
config["dmlc_timeout_sec"] = static_cast<std::int64_t>(timeout.count());
config["dmlc_task_id"] = std::to_string(r);
config["dmlc_retry"] = 2;
return config;
}
} // namespace xgboost::collective