Fix multi-threaded gtests (#9148)
This commit is contained in:
@@ -497,23 +497,32 @@ inline std::int32_t AllThreadsForTest() { return Context{}.Threads(); }
|
||||
|
||||
template <typename Function, typename... Args>
|
||||
void RunWithInMemoryCommunicator(int32_t world_size, Function&& function, Args&&... args) {
|
||||
auto run = [&](auto rank) {
|
||||
Json config{JsonObject()};
|
||||
config["xgboost_communicator"] = String("in-memory");
|
||||
config["in_memory_world_size"] = world_size;
|
||||
config["in_memory_rank"] = rank;
|
||||
xgboost::collective::Init(config);
|
||||
|
||||
std::forward<Function>(function)(std::forward<Args>(args)...);
|
||||
|
||||
xgboost::collective::Finalize();
|
||||
};
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel num_threads(world_size)
|
||||
{
|
||||
auto rank = omp_get_thread_num();
|
||||
run(rank);
|
||||
}
|
||||
#else
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < world_size; rank++) {
|
||||
threads.emplace_back([&, rank]() {
|
||||
Json config{JsonObject()};
|
||||
config["xgboost_communicator"] = String("in-memory");
|
||||
config["in_memory_world_size"] = world_size;
|
||||
config["in_memory_rank"] = rank;
|
||||
xgboost::collective::Init(config);
|
||||
|
||||
std::forward<Function>(function)(std::forward<Args>(args)...);
|
||||
|
||||
xgboost::collective::Finalize();
|
||||
});
|
||||
threads.emplace_back(run, rank);
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
class DeclareUnifiedDistributedTest(MetricTest) : public ::testing::Test {
|
||||
|
||||
Reference in New Issue
Block a user