[col] Small cleanup to federated comm. (#10397)

This commit is contained in:
Jiaming Yuan 2024-06-07 21:19:04 +08:00 committed by GitHub
parent f5815b6982
commit c9f5fcaf21
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 24 additions and 14 deletions

View File

@ -53,8 +53,8 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max()); args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port), auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port),
grpc::SslCredentials(options), args); grpc::SslCredentials(options), args);
channel->WaitForConnected( channel->WaitForConnected(gpr_time_add(
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN))); gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(DefaultTimeoutSec(), GPR_TIMESPAN)));
return federated::Federated::NewStub(channel); return federated::Federated::NewStub(channel);
}(); }();
} }
@ -90,8 +90,6 @@ FederatedComm::FederatedComm(std::int32_t retry, std::chrono::seconds timeout, s
auto parsed = common::Split(server_address, ':'); auto parsed = common::Split(server_address, ':');
CHECK_EQ(parsed.size(), 2) << "Invalid server address:" << server_address; CHECK_EQ(parsed.size(), 2) << "Invalid server address:" << server_address;
CHECK_NE(rank, -1) << "Parameter `federated_rank` is required";
CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required."; CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required.";
/** /**

View File

@ -6,8 +6,9 @@
#include <federated.grpc.pb.h> #include <federated.grpc.pb.h>
#include <federated.pb.h> #include <federated.pb.h>
#include <chrono> // for seconds
#include <cstdint> // for int32_t #include <cstdint> // for int32_t
#include <memory> // for unique_ptr #include <memory> // for shared_ptr
#include <string> // for string #include <string> // for string
#include "../../src/collective/comm.h" // for HostComm #include "../../src/collective/comm.h" // for HostComm
@ -46,10 +47,6 @@ class FederatedComm : public HostComm {
*/ */
explicit FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id, explicit FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id,
Json const& config); Json const& config);
explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world,
std::int32_t rank) {
this->Init(host, port, world, rank, {}, {}, {});
}
[[nodiscard]] Result Shutdown() final { [[nodiscard]] Result Shutdown() final {
this->ResetState(); this->ResetState();
return Success(); return Success();

View File

@ -16,21 +16,35 @@
namespace xgboost::collective { namespace xgboost::collective {
namespace { namespace {
class FederatedCommTest : public SocketTest {}; class FederatedCommTest : public SocketTest {};
auto MakeConfig(std::string host, std::int32_t port, std::int32_t world, std::int32_t rank) {
Json config{Object{}};
config["federated_server_address"] = host + ":" + std::to_string(port);
config["federated_world_size"] = Integer{world};
config["federated_rank"] = Integer{rank};
return config;
}
} // namespace } // namespace
TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) { TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; }; auto config = MakeConfig("localhost", 0, 0, 0);
auto construct = [config] {
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
};
ASSERT_THAT(construct, GMockThrow("Invalid world size")); ASSERT_THAT(construct, GMockThrow("Invalid world size"));
} }
TEST_F(FederatedCommTest, ThrowOnRankTooSmall) { TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; }; auto config = MakeConfig("localhost", 0, 1, -1);
auto construct = [config] {
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
};
ASSERT_THAT(construct, GMockThrow("Invalid worker rank.")); ASSERT_THAT(construct, GMockThrow("Invalid worker rank."));
} }
TEST_F(FederatedCommTest, ThrowOnRankTooBig) { TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
auto construct = [] { auto config = MakeConfig("localhost", 0, 1, 1);
FederatedComm comm{"localhost", 0, 1, 1}; auto construct = [config] {
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
}; };
ASSERT_THAT(construct, GMockThrow("Invalid worker rank.")); ASSERT_THAT(construct, GMockThrow("Invalid worker rank."));
} }
@ -68,7 +82,8 @@ TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
} }
TEST_F(FederatedCommTest, IsDistributed) { TEST_F(FederatedCommTest, IsDistributed) {
FederatedComm comm{"localhost", 0, 2, 1}; FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "",
MakeConfig("localhost", 0, 2, 1)};
EXPECT_TRUE(comm.IsDistributed()); EXPECT_TRUE(comm.IsDistributed());
} }