Convert federated learner test into test suite. (#9018)

* Convert federated learner test into test suite.

- Add specialization to learning to rank.
This commit is contained in:
Jiaming Yuan
2023-04-11 09:52:55 +08:00
committed by GitHub
parent 2c8d735cb3
commit fe9dff339c
8 changed files with 152 additions and 104 deletions

View File

@@ -92,7 +92,7 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) {
config["federated_server_address"] = server_address;
config["federated_world_size"] = std::string("1");
config["federated_rank"] = Integer(0);
auto *comm = FederatedCommunicator::Create(config);
FederatedCommunicator::Create(config);
};
EXPECT_THROW(construct(), dmlc::Error);
}
@@ -104,7 +104,7 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) {
config["federated_server_address"] = server_address;
config["federated_world_size"] = 1;
config["federated_rank"] = std::string("0");
auto *comm = FederatedCommunicator::Create(config);
FederatedCommunicator::Create(config);
};
EXPECT_THROW(construct(), dmlc::Error);
}
@@ -125,7 +125,7 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) {
TEST_F(FederatedCommunicatorTest, Allgather) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgather, rank, server_address_);
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgather, rank, server_->Address());
}
for (auto &thread : threads) {
thread.join();
@@ -135,7 +135,7 @@ TEST_F(FederatedCommunicatorTest, Allgather) {
TEST_F(FederatedCommunicatorTest, Allreduce) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_address_);
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_->Address());
}
for (auto &thread : threads) {
thread.join();
@@ -145,7 +145,7 @@ TEST_F(FederatedCommunicatorTest, Allreduce) {
TEST_F(FederatedCommunicatorTest, Broadcast) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_address_);
threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_->Address());
}
for (auto &thread : threads) {
thread.join();