xgboost/tests/cpp/collective/test_tracker.cc
Jiaming Yuan 3fbb221fec
[coll] Implement shutdown for tracker and comm. (#10208)
- Force shutdown the tracker.
- Implement shutdown notice for error handling thread in comm.
2024-04-20 04:08:17 +08:00

121 lines
3.1 KiB
C++

/**
* Copyright 2023-2024, XGBoost Contributors
*/
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <string> // for string
#include <thread> // for thread
#include <vector> // for vector
#include "../../../src/collective/comm.h"
#include "../helpers.h" // for GMockThrow
#include "test_worker.h"
namespace xgboost::collective {
namespace {
class PrintWorker : public WorkerForTest {
public:
using WorkerForTest::WorkerForTest;
void Print() {
auto rc = comm_.LogTracker("ack:" + std::to_string(this->comm_.Rank()));
SafeColl(rc);
}
};
} // namespace
TEST_F(TrackerTest, Bootstrap) {
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
ASSERT_FALSE(tracker.Ready());
auto fut = tracker.Run();
std::vector<std::thread> workers;
auto args = tracker.WorkerArgs();
ASSERT_TRUE(tracker.Ready());
ASSERT_EQ(get<String const>(args["dmlc_tracker_uri"]), host);
std::int32_t port = tracker.Port();
for (std::int32_t i = 0; i < n_workers; ++i) {
workers.emplace_back([=] { WorkerForTest worker{host, port, timeout, n_workers, i}; });
}
for (auto &w : workers) {
w.join();
}
SafeColl(fut.get());
}
TEST_F(TrackerTest, Print) {
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
auto fut = tracker.Run();
std::vector<std::thread> workers;
auto rc = tracker.WaitUntilReady();
ASSERT_TRUE(rc.OK());
std::int32_t port = tracker.Port();
for (std::int32_t i = 0; i < n_workers; ++i) {
workers.emplace_back([=] {
PrintWorker worker{host, port, timeout, n_workers, i};
worker.Print();
});
}
for (auto &w : workers) {
w.join();
}
ASSERT_TRUE(fut.get().OK());
}
TEST_F(TrackerTest, GetHostAddress) { ASSERT_TRUE(host.find("127.") == std::string::npos); }
/**
* Test connecting the tracker after it has finished. This should not hang the workers.
*/
TEST_F(TrackerTest, AfterShutdown) {
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
auto fut = tracker.Run();
std::vector<std::thread> workers;
auto rc = tracker.WaitUntilReady();
ASSERT_TRUE(rc.OK());
std::int32_t port = tracker.Port();
// Launch no-op workers to cause the tracker to shutdown.
for (std::int32_t i = 0; i < n_workers; ++i) {
workers.emplace_back([=] { WorkerForTest worker{host, port, timeout, n_workers, i}; });
}
for (auto &w : workers) {
w.join();
}
ASSERT_TRUE(fut.get().OK());
// Launch workers again, they should fail.
workers.clear();
for (std::int32_t i = 0; i < n_workers; ++i) {
auto assert_that = [=] {
WorkerForTest worker{host, port, timeout, n_workers, i};
};
// On a Linux platform, the connection will be refused, on Apple platform, this gets
// an operation now in progress poll failure, on Windows, it's a timeout error.
#if defined(__linux__)
workers.emplace_back([=] { ASSERT_THAT(assert_that, GMockThrow("Connection refused")); });
#else
workers.emplace_back([=] { ASSERT_THAT(assert_that, GMockThrow("Failed to connect to")); });
#endif
}
for (auto &w : workers) {
w.join();
}
}
} // namespace xgboost::collective