Fix multi-threaded gtests (#9148)

This commit is contained in:
Rong Ou
2023-05-10 04:15:32 -07:00
committed by GitHub
parent e4129ed6ee
commit 52311dcec9
6 changed files with 120 additions and 50 deletions

View File

@@ -497,23 +497,32 @@ inline std::int32_t AllThreadsForTest() { return Context{}.Threads(); }
template <typename Function, typename... Args>
void RunWithInMemoryCommunicator(int32_t world_size, Function&& function, Args&&... args) {
auto run = [&](auto rank) {
Json config{JsonObject()};
config["xgboost_communicator"] = String("in-memory");
config["in_memory_world_size"] = world_size;
config["in_memory_rank"] = rank;
xgboost::collective::Init(config);
std::forward<Function>(function)(std::forward<Args>(args)...);
xgboost::collective::Finalize();
};
#if defined(_OPENMP)
#pragma omp parallel num_threads(world_size)
{
auto rank = omp_get_thread_num();
run(rank);
}
#else
std::vector<std::thread> threads;
for (auto rank = 0; rank < world_size; rank++) {
threads.emplace_back([&, rank]() {
Json config{JsonObject()};
config["xgboost_communicator"] = String("in-memory");
config["in_memory_world_size"] = world_size;
config["in_memory_rank"] = rank;
xgboost::collective::Init(config);
std::forward<Function>(function)(std::forward<Args>(args)...);
xgboost::collective::Finalize();
});
threads.emplace_back(run, rank);
}
for (auto& thread : threads) {
thread.join();
}
#endif
}
class DeclareUnifiedDistributedTest(MetricTest) : public ::testing::Test {