diff --git a/plugin/federated/federated_comm.cc b/plugin/federated/federated_comm.cc index ec1287413..578c25bdb 100644 --- a/plugin/federated/federated_comm.cc +++ b/plugin/federated/federated_comm.cc @@ -53,8 +53,8 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_ args.SetMaxReceiveMessageSize(std::numeric_limits::max()); auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port), grpc::SslCredentials(options), args); - channel->WaitForConnected( - gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN))); + channel->WaitForConnected(gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(DefaultTimeoutSec(), GPR_TIMESPAN))); return federated::Federated::NewStub(channel); }(); } @@ -90,8 +90,6 @@ FederatedComm::FederatedComm(std::int32_t retry, std::chrono::seconds timeout, s auto parsed = common::Split(server_address, ':'); CHECK_EQ(parsed.size(), 2) << "Invalid server address:" << server_address; - CHECK_NE(rank, -1) << "Parameter `federated_rank` is required"; - CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required."; CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required."; /** diff --git a/plugin/federated/federated_comm.h b/plugin/federated/federated_comm.h index b39e1878a..0909509e0 100644 --- a/plugin/federated/federated_comm.h +++ b/plugin/federated/federated_comm.h @@ -6,8 +6,9 @@ #include #include +#include // for seconds #include // for int32_t -#include // for unique_ptr +#include // for shared_ptr #include // for string #include "../../src/collective/comm.h" // for HostComm @@ -46,10 +47,6 @@ class FederatedComm : public HostComm { */ explicit FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id, Json const& config); - explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world, - std::int32_t rank) { - this->Init(host, port, world, rank, {}, {}, {}); - } [[nodiscard]] Result Shutdown() final { this->ResetState(); return Success(); diff --git a/tests/cpp/plugin/federated/test_federated_comm.cc b/tests/cpp/plugin/federated/test_federated_comm.cc index 16edc685f..a82e9c2c4 100644 --- a/tests/cpp/plugin/federated/test_federated_comm.cc +++ b/tests/cpp/plugin/federated/test_federated_comm.cc @@ -16,21 +16,35 @@ namespace xgboost::collective { namespace { class FederatedCommTest : public SocketTest {}; +auto MakeConfig(std::string host, std::int32_t port, std::int32_t world, std::int32_t rank) { + Json config{Object{}}; + config["federated_server_address"] = host + ":" + std::to_string(port); + config["federated_world_size"] = Integer{world}; + config["federated_rank"] = Integer{rank}; + return config; +} } // namespace TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) { - auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; }; + auto config = MakeConfig("localhost", 0, 0, 0); + auto construct = [config] { + FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config}; + }; ASSERT_THAT(construct, GMockThrow("Invalid world size")); } TEST_F(FederatedCommTest, ThrowOnRankTooSmall) { - auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; }; + auto config = MakeConfig("localhost", 0, 1, -1); + auto construct = [config] { + FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config}; + }; ASSERT_THAT(construct, GMockThrow("Invalid worker rank.")); } TEST_F(FederatedCommTest, ThrowOnRankTooBig) { - auto construct = [] { - FederatedComm comm{"localhost", 0, 1, 1}; + auto config = MakeConfig("localhost", 0, 1, 1); + auto construct = [config] { + FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config}; }; ASSERT_THAT(construct, GMockThrow("Invalid worker rank.")); } @@ -68,7 +82,8 @@ TEST_F(FederatedCommTest, GetWorldSizeAndRank) { } TEST_F(FederatedCommTest, IsDistributed) { - FederatedComm comm{"localhost", 0, 2, 1}; + FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", + MakeConfig("localhost", 0, 2, 1)}; EXPECT_TRUE(comm.IsDistributed()); }