sync up May15 2023
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <dmlc/omp.h>
|
||||
#include <grpcpp/server_builder.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/json.h>
|
||||
@@ -61,24 +62,33 @@ class BaseFederatedTest : public ::testing::Test {
|
||||
template <typename Function, typename... Args>
|
||||
void RunWithFederatedCommunicator(int32_t world_size, std::string const& server_address,
|
||||
Function&& function, Args&&... args) {
|
||||
auto run = [&](auto rank) {
|
||||
Json config{JsonObject()};
|
||||
config["xgboost_communicator"] = String("federated");
|
||||
config["federated_server_address"] = String(server_address);
|
||||
config["federated_world_size"] = world_size;
|
||||
config["federated_rank"] = rank;
|
||||
xgboost::collective::Init(config);
|
||||
|
||||
std::forward<Function>(function)(std::forward<Args>(args)...);
|
||||
|
||||
xgboost::collective::Finalize();
|
||||
};
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel num_threads(world_size)
|
||||
{
|
||||
auto rank = omp_get_thread_num();
|
||||
run(rank);
|
||||
}
|
||||
#else
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < world_size; rank++) {
|
||||
threads.emplace_back([&, rank]() {
|
||||
Json config{JsonObject()};
|
||||
config["xgboost_communicator"] = String("federated");
|
||||
config["federated_server_address"] = String(server_address);
|
||||
config["federated_world_size"] = world_size;
|
||||
config["federated_rank"] = rank;
|
||||
xgboost::collective::Init(config);
|
||||
|
||||
std::forward<Function>(function)(std::forward<Args>(args)...);
|
||||
|
||||
xgboost::collective::Finalize();
|
||||
});
|
||||
threads.emplace_back(run, rank);
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user