Add timeout for distributed tests. (#10315)
This commit is contained in:
parent
b8a7773736
commit
d5fcbee44b
@@ -191,7 +191,7 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const>
|
||||
for (std::int32_t r = 0; r < comm->World(); ++r) {
|
||||
auto as_bytes = sizes[r];
|
||||
auto rc = stub->Broadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
|
||||
- ncclInt8, r, comm->Handle(), dh::DefaultStream());
|
||||
+ ncclInt8, r, comm->Handle(), comm->Stream());
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
@@ -147,7 +147,8 @@ inline auto MakeDistributedTestConfig(std::string host, std::int32_t port,
|
||||
}
|
||||
|
||||
template <typename WorkerFn>
|
||||
- void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need_finalize = true) {
|
||||
+ void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need_finalize = true,
|
||||
+ std::chrono::seconds test_timeout = std::chrono::seconds{30}) {
|
||||
system::SocketStartup();
|
||||
std::chrono::seconds timeout{1};
|
||||
|
||||
@@ -163,12 +164,17 @@ void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
workers.emplace_back([=] {
|
||||
- auto config = MakeDistributedTestConfig(host, port, timeout, i);
|
||||
- Init(config);
|
||||
- worker_fn();
|
||||
- if (need_finalize) {
|
||||
- Finalize();
|
||||
- }
|
||||
+ auto fut = std::async(std::launch::async, [=] {
|
||||
+ auto config = MakeDistributedTestConfig(host, port, timeout, i);
|
||||
+ Init(config);
|
||||
+ worker_fn();
|
||||
+ if (need_finalize) {
|
||||
+ Finalize();
|
||||
+ }
|
||||
+ });
|
||||
+ auto status = fut.wait_for(test_timeout);
|
||||
+ CHECK(status == std::future_status::ready) << "Test timeout";
|
||||
+ fut.get();
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user