[col] Small cleanup to federated comm. (#10397)

This commit is contained in:
Jiaming Yuan
2024-06-07 21:19:04 +08:00
committed by GitHub
parent f5815b6982
commit c9f5fcaf21
3 changed files with 24 additions and 14 deletions

View File

@@ -16,21 +16,35 @@
namespace xgboost::collective {
namespace {
class FederatedCommTest : public SocketTest {};
auto MakeConfig(std::string host, std::int32_t port, std::int32_t world, std::int32_t rank) {
Json config{Object{}};
config["federated_server_address"] = host + ":" + std::to_string(port);
config["federated_world_size"] = Integer{world};
config["federated_rank"] = Integer{rank};
return config;
}
} // namespace
TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; };
auto config = MakeConfig("localhost", 0, 0, 0);
auto construct = [config] {
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
};
ASSERT_THAT(construct, GMockThrow("Invalid world size"));
}
TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; };
auto config = MakeConfig("localhost", 0, 1, -1);
auto construct = [config] {
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
};
ASSERT_THAT(construct, GMockThrow("Invalid worker rank."));
}
TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
auto construct = [] {
FederatedComm comm{"localhost", 0, 1, 1};
auto config = MakeConfig("localhost", 0, 1, 1);
auto construct = [config] {
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
};
ASSERT_THAT(construct, GMockThrow("Invalid worker rank."));
}
@@ -68,7 +82,8 @@ TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
}
TEST_F(FederatedCommTest, IsDistributed) {
FederatedComm comm{"localhost", 0, 2, 1};
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "",
MakeConfig("localhost", 0, 2, 1)};
EXPECT_TRUE(comm.IsDistributed());
}