Convert federated learner test into test suite. (#9018)

* Convert federated learner test into test suite.

- Add specialization to learning to rank.
This commit is contained in:
Jiaming Yuan
2023-04-11 09:52:55 +08:00
committed by GitHub
parent 2c8d735cb3
commit fe9dff339c
8 changed files with 152 additions and 104 deletions

View File

@@ -8,6 +8,7 @@
#include <xgboost/json.h>
#include <random>
#include <thread> // for thread, sleep_for
#include "../../../plugin/federated/federated_server.h"
#include "../../../src/collective/communicator-inl.h"
@@ -33,13 +34,17 @@ inline std::string GetServerAddress() {
namespace xgboost {
class BaseFederatedTest : public ::testing::Test {
protected:
void SetUp() override {
class ServerForTest {
std::string server_address_;
std::unique_ptr<std::thread> server_thread_;
std::unique_ptr<grpc::Server> server_;
public:
explicit ServerForTest(std::int32_t world_size) {
server_address_ = GetServerAddress();
server_thread_.reset(new std::thread([this] {
server_thread_.reset(new std::thread([this, world_size] {
grpc::ServerBuilder builder;
xgboost::federated::FederatedService service{kWorldSize};
xgboost::federated::FederatedService service{world_size};
builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
@@ -47,15 +52,21 @@ class BaseFederatedTest : public ::testing::Test {
}));
}
void TearDown() override {
~ServerForTest() {
server_->Shutdown();
server_thread_->join();
}
auto Address() const { return server_address_; }
};
class BaseFederatedTest : public ::testing::Test {
protected:
void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); }
void TearDown() override { server_.reset(nullptr); }
static int const kWorldSize{3};
std::string server_address_;
std::unique_ptr<std::thread> server_thread_;
std::unique_ptr<grpc::Server> server_;
std::unique_ptr<ServerForTest> server_;
};
template <typename Function, typename... Args>