Add timeout for distributed tests. (#10315)

This commit is contained in:
Jiaming Yuan 2024-05-23 11:11:49 +08:00 committed by GitHub
parent b8a7773736
commit d5fcbee44b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 8 deletions

View File

@ -191,7 +191,7 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const>
for (std::int32_t r = 0; r < comm->World(); ++r) { for (std::int32_t r = 0; r < comm->World(); ++r) {
auto as_bytes = sizes[r]; auto as_bytes = sizes[r];
auto rc = stub->Broadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes, auto rc = stub->Broadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
ncclInt8, r, comm->Handle(), dh::DefaultStream()); ncclInt8, r, comm->Handle(), comm->Stream());
if (!rc.OK()) { if (!rc.OK()) {
return rc; return rc;
} }

View File

@ -147,7 +147,8 @@ inline auto MakeDistributedTestConfig(std::string host, std::int32_t port,
} }
template <typename WorkerFn> template <typename WorkerFn>
void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need_finalize = true) { void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need_finalize = true,
std::chrono::seconds test_timeout = std::chrono::seconds{30}) {
system::SocketStartup(); system::SocketStartup();
std::chrono::seconds timeout{1}; std::chrono::seconds timeout{1};
@ -163,12 +164,17 @@ void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need
for (std::int32_t i = 0; i < n_workers; ++i) { for (std::int32_t i = 0; i < n_workers; ++i) {
workers.emplace_back([=] { workers.emplace_back([=] {
auto config = MakeDistributedTestConfig(host, port, timeout, i); auto fut = std::async(std::launch::async, [=] {
Init(config); auto config = MakeDistributedTestConfig(host, port, timeout, i);
worker_fn(); Init(config);
if (need_finalize) { worker_fn();
Finalize(); if (need_finalize) {
} Finalize();
}
});
auto status = fut.wait_for(test_timeout);
CHECK(status == std::future_status::ready) << "Test timeout";
fut.get();
}); });
} }