[col] Small cleanup to federated comm. (#10397)
This commit is contained in:
parent
f5815b6982
commit
c9f5fcaf21
@ -53,8 +53,8 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
|
|||||||
args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
|
args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
|
||||||
auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port),
|
auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port),
|
||||||
grpc::SslCredentials(options), args);
|
grpc::SslCredentials(options), args);
|
||||||
channel->WaitForConnected(
|
channel->WaitForConnected(gpr_time_add(
|
||||||
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
|
gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(DefaultTimeoutSec(), GPR_TIMESPAN)));
|
||||||
return federated::Federated::NewStub(channel);
|
return federated::Federated::NewStub(channel);
|
||||||
}();
|
}();
|
||||||
}
|
}
|
||||||
@ -90,8 +90,6 @@ FederatedComm::FederatedComm(std::int32_t retry, std::chrono::seconds timeout, s
|
|||||||
auto parsed = common::Split(server_address, ':');
|
auto parsed = common::Split(server_address, ':');
|
||||||
CHECK_EQ(parsed.size(), 2) << "Invalid server address:" << server_address;
|
CHECK_EQ(parsed.size(), 2) << "Invalid server address:" << server_address;
|
||||||
|
|
||||||
CHECK_NE(rank, -1) << "Parameter `federated_rank` is required";
|
|
||||||
CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
|
|
||||||
CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required.";
|
CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required.";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -6,8 +6,9 @@
|
|||||||
#include <federated.grpc.pb.h>
|
#include <federated.grpc.pb.h>
|
||||||
#include <federated.pb.h>
|
#include <federated.pb.h>
|
||||||
|
|
||||||
|
#include <chrono> // for seconds
|
||||||
#include <cstdint> // for int32_t
|
#include <cstdint> // for int32_t
|
||||||
#include <memory> // for unique_ptr
|
#include <memory> // for shared_ptr
|
||||||
#include <string> // for string
|
#include <string> // for string
|
||||||
|
|
||||||
#include "../../src/collective/comm.h" // for HostComm
|
#include "../../src/collective/comm.h" // for HostComm
|
||||||
@ -46,10 +47,6 @@ class FederatedComm : public HostComm {
|
|||||||
*/
|
*/
|
||||||
explicit FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id,
|
explicit FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id,
|
||||||
Json const& config);
|
Json const& config);
|
||||||
explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world,
|
|
||||||
std::int32_t rank) {
|
|
||||||
this->Init(host, port, world, rank, {}, {}, {});
|
|
||||||
}
|
|
||||||
[[nodiscard]] Result Shutdown() final {
|
[[nodiscard]] Result Shutdown() final {
|
||||||
this->ResetState();
|
this->ResetState();
|
||||||
return Success();
|
return Success();
|
||||||
|
|||||||
@ -16,21 +16,35 @@
|
|||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
namespace {
|
namespace {
|
||||||
class FederatedCommTest : public SocketTest {};
|
class FederatedCommTest : public SocketTest {};
|
||||||
|
auto MakeConfig(std::string host, std::int32_t port, std::int32_t world, std::int32_t rank) {
|
||||||
|
Json config{Object{}};
|
||||||
|
config["federated_server_address"] = host + ":" + std::to_string(port);
|
||||||
|
config["federated_world_size"] = Integer{world};
|
||||||
|
config["federated_rank"] = Integer{rank};
|
||||||
|
return config;
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
|
TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
|
||||||
auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; };
|
auto config = MakeConfig("localhost", 0, 0, 0);
|
||||||
|
auto construct = [config] {
|
||||||
|
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
|
||||||
|
};
|
||||||
ASSERT_THAT(construct, GMockThrow("Invalid world size"));
|
ASSERT_THAT(construct, GMockThrow("Invalid world size"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
|
TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
|
||||||
auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; };
|
auto config = MakeConfig("localhost", 0, 1, -1);
|
||||||
|
auto construct = [config] {
|
||||||
|
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
|
||||||
|
};
|
||||||
ASSERT_THAT(construct, GMockThrow("Invalid worker rank."));
|
ASSERT_THAT(construct, GMockThrow("Invalid worker rank."));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
|
TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
|
||||||
auto construct = [] {
|
auto config = MakeConfig("localhost", 0, 1, 1);
|
||||||
FederatedComm comm{"localhost", 0, 1, 1};
|
auto construct = [config] {
|
||||||
|
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
|
||||||
};
|
};
|
||||||
ASSERT_THAT(construct, GMockThrow("Invalid worker rank."));
|
ASSERT_THAT(construct, GMockThrow("Invalid worker rank."));
|
||||||
}
|
}
|
||||||
@ -68,7 +82,8 @@ TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(FederatedCommTest, IsDistributed) {
|
TEST_F(FederatedCommTest, IsDistributed) {
|
||||||
FederatedComm comm{"localhost", 0, 2, 1};
|
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "",
|
||||||
|
MakeConfig("localhost", 0, 2, 1)};
|
||||||
EXPECT_TRUE(comm.IsDistributed());
|
EXPECT_TRUE(comm.IsDistributed());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user