Convert federated learner test into test suite. (#9018)
* Convert federated learner test into test suite. - Add specialization to learning to rank.
This commit is contained in:
@@ -92,7 +92,7 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) {
|
||||
config["federated_server_address"] = server_address;
|
||||
config["federated_world_size"] = std::string("1");
|
||||
config["federated_rank"] = Integer(0);
|
||||
auto *comm = FederatedCommunicator::Create(config);
|
||||
FederatedCommunicator::Create(config);
|
||||
};
|
||||
EXPECT_THROW(construct(), dmlc::Error);
|
||||
}
|
||||
@@ -104,7 +104,7 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) {
|
||||
config["federated_server_address"] = server_address;
|
||||
config["federated_world_size"] = 1;
|
||||
config["federated_rank"] = std::string("0");
|
||||
auto *comm = FederatedCommunicator::Create(config);
|
||||
FederatedCommunicator::Create(config);
|
||||
};
|
||||
EXPECT_THROW(construct(), dmlc::Error);
|
||||
}
|
||||
@@ -125,7 +125,7 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) {
|
||||
TEST_F(FederatedCommunicatorTest, Allgather) {
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgather, rank, server_address_);
|
||||
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgather, rank, server_->Address());
|
||||
}
|
||||
for (auto &thread : threads) {
|
||||
thread.join();
|
||||
@@ -135,7 +135,7 @@ TEST_F(FederatedCommunicatorTest, Allgather) {
|
||||
TEST_F(FederatedCommunicatorTest, Allreduce) {
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_address_);
|
||||
threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_->Address());
|
||||
}
|
||||
for (auto &thread : threads) {
|
||||
thread.join();
|
||||
@@ -145,7 +145,7 @@ TEST_F(FederatedCommunicatorTest, Allreduce) {
|
||||
TEST_F(FederatedCommunicatorTest, Broadcast) {
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||
threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_address_);
|
||||
threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_->Address());
|
||||
}
|
||||
for (auto &thread : threads) {
|
||||
thread.join();
|
||||
|
||||
Reference in New Issue
Block a user