Add timeout for distributed tests. (#10315)
This commit is contained in:
parent
b8a7773736
commit
d5fcbee44b
@ -191,7 +191,7 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span<std::int8_t const>
|
|||||||
for (std::int32_t r = 0; r < comm->World(); ++r) {
|
for (std::int32_t r = 0; r < comm->World(); ++r) {
|
||||||
auto as_bytes = sizes[r];
|
auto as_bytes = sizes[r];
|
||||||
auto rc = stub->Broadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
|
auto rc = stub->Broadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes,
|
||||||
ncclInt8, r, comm->Handle(), dh::DefaultStream());
|
ncclInt8, r, comm->Handle(), comm->Stream());
|
||||||
if (!rc.OK()) {
|
if (!rc.OK()) {
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -147,7 +147,8 @@ inline auto MakeDistributedTestConfig(std::string host, std::int32_t port,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename WorkerFn>
|
template <typename WorkerFn>
|
||||||
void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need_finalize = true) {
|
void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need_finalize = true,
|
||||||
|
std::chrono::seconds test_timeout = std::chrono::seconds{30}) {
|
||||||
system::SocketStartup();
|
system::SocketStartup();
|
||||||
std::chrono::seconds timeout{1};
|
std::chrono::seconds timeout{1};
|
||||||
|
|
||||||
@ -163,6 +164,7 @@ void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need
|
|||||||
|
|
||||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||||
workers.emplace_back([=] {
|
workers.emplace_back([=] {
|
||||||
|
auto fut = std::async(std::launch::async, [=] {
|
||||||
auto config = MakeDistributedTestConfig(host, port, timeout, i);
|
auto config = MakeDistributedTestConfig(host, port, timeout, i);
|
||||||
Init(config);
|
Init(config);
|
||||||
worker_fn();
|
worker_fn();
|
||||||
@ -170,6 +172,10 @@ void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need
|
|||||||
Finalize();
|
Finalize();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
auto status = fut.wait_for(test_timeout);
|
||||||
|
CHECK(status == std::future_status::ready) << "Test timeout";
|
||||||
|
fut.get();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& t : workers) {
|
for (auto& t : workers) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user