From d5fcbee44be39e6b4f16a56d26cc2282cbb3d3cc Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 23 May 2024 11:11:49 +0800 Subject: [PATCH] Add timeout for distributed tests. (#10315) --- src/collective/coll.cu | 2 +- tests/cpp/collective/test_worker.h | 20 +++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/collective/coll.cu b/src/collective/coll.cu index b06435bfe..433f1e49d 100644 --- a/src/collective/coll.cu +++ b/src/collective/coll.cu @@ -191,7 +191,7 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span for (std::int32_t r = 0; r < comm->World(); ++r) { auto as_bytes = sizes[r]; auto rc = stub->Broadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes, - ncclInt8, r, comm->Handle(), dh::DefaultStream()); + ncclInt8, r, comm->Handle(), comm->Stream()); if (!rc.OK()) { return rc; } diff --git a/tests/cpp/collective/test_worker.h b/tests/cpp/collective/test_worker.h index 78f4a28d8..f1889200b 100644 --- a/tests/cpp/collective/test_worker.h +++ b/tests/cpp/collective/test_worker.h @@ -147,7 +147,8 @@ inline auto MakeDistributedTestConfig(std::string host, std::int32_t port, } template -void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need_finalize = true) { +void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need_finalize = true, + std::chrono::seconds test_timeout = std::chrono::seconds{30}) { system::SocketStartup(); std::chrono::seconds timeout{1}; @@ -163,12 +164,17 @@ void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need for (std::int32_t i = 0; i < n_workers; ++i) { workers.emplace_back([=] { - auto config = MakeDistributedTestConfig(host, port, timeout, i); - Init(config); - worker_fn(); - if (need_finalize) { - Finalize(); - } + auto fut = std::async(std::launch::async, [=] { + auto config = MakeDistributedTestConfig(host, port, timeout, i); + Init(config); + worker_fn(); + if (need_finalize) { + Finalize(); + } + }); + auto status = fut.wait_for(test_timeout); + CHECK(status == std::future_status::ready) << "Test timeout"; + fut.get(); }); }