[col] Small cleanup to federated comm. (#10397)
This commit is contained in:
@@ -16,21 +16,35 @@
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
class FederatedCommTest : public SocketTest {};
|
||||
auto MakeConfig(std::string host, std::int32_t port, std::int32_t world, std::int32_t rank) {
|
||||
Json config{Object{}};
|
||||
config["federated_server_address"] = host + ":" + std::to_string(port);
|
||||
config["federated_world_size"] = Integer{world};
|
||||
config["federated_rank"] = Integer{rank};
|
||||
return config;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
|
||||
auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; };
|
||||
auto config = MakeConfig("localhost", 0, 0, 0);
|
||||
auto construct = [config] {
|
||||
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
|
||||
};
|
||||
ASSERT_THAT(construct, GMockThrow("Invalid world size"));
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
|
||||
auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; };
|
||||
auto config = MakeConfig("localhost", 0, 1, -1);
|
||||
auto construct = [config] {
|
||||
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
|
||||
};
|
||||
ASSERT_THAT(construct, GMockThrow("Invalid worker rank."));
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
|
||||
auto construct = [] {
|
||||
FederatedComm comm{"localhost", 0, 1, 1};
|
||||
auto config = MakeConfig("localhost", 0, 1, 1);
|
||||
auto construct = [config] {
|
||||
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
|
||||
};
|
||||
ASSERT_THAT(construct, GMockThrow("Invalid worker rank."));
|
||||
}
|
||||
@@ -68,7 +82,8 @@ TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, IsDistributed) {
|
||||
FederatedComm comm{"localhost", 0, 2, 1};
|
||||
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "",
|
||||
MakeConfig("localhost", 0, 2, 1)};
|
||||
EXPECT_TRUE(comm.IsDistributed());
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user