Utilities and cleanups for socket. (#9576)
- Use C++17 `[[nodiscard]]` and nested namespaces. - Add a bind method to the socket. - Remove rabit parameters.
This commit is contained in:
parent
5abe50ff8c
commit
b438d684d2
@ -1554,29 +1554,19 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
|
|||||||
* \param config JSON encoded configuration. Accepted JSON keys are:
|
* \param config JSON encoded configuration. Accepted JSON keys are:
|
||||||
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
|
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
|
||||||
* * rabit: Use Rabit. This is the default if the type is unspecified.
|
* * rabit: Use Rabit. This is the default if the type is unspecified.
|
||||||
* * mpi: Use MPI.
|
|
||||||
* * federated: Use the gRPC interface for Federated Learning.
|
* * federated: Use the gRPC interface for Federated Learning.
|
||||||
* Only applicable to the Rabit communicator (these are case-sensitive):
|
* Only applicable to the Rabit communicator (these are case-sensitive):
|
||||||
* - rabit_tracker_uri: Hostname of the tracker.
|
* - rabit_tracker_uri: Hostname of the tracker.
|
||||||
* - rabit_tracker_port: Port number of the tracker.
|
* - rabit_tracker_port: Port number of the tracker.
|
||||||
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
|
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||||
* - rabit_world_size: Total number of workers.
|
* - rabit_world_size: Total number of workers.
|
||||||
* - rabit_hadoop_mode: Enable Hadoop support.
|
|
||||||
* - rabit_tree_reduce_minsize: Minimal size for tree reduce.
|
|
||||||
* - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
|
|
||||||
* - rabit_reduce_buffer: Size of the reduce buffer.
|
|
||||||
* - rabit_bootstrap_cache: Size of the bootstrap cache.
|
|
||||||
* - rabit_debug: Enable debugging.
|
|
||||||
* - rabit_timeout: Enable timeout.
|
* - rabit_timeout: Enable timeout.
|
||||||
* - rabit_timeout_sec: Timeout in seconds.
|
* - rabit_timeout_sec: Timeout in seconds.
|
||||||
* - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
|
|
||||||
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
|
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
|
||||||
* environment variables):
|
* environment variables):
|
||||||
* - DMLC_TRACKER_URI: Hostname of the tracker.
|
* - DMLC_TRACKER_URI: Hostname of the tracker.
|
||||||
* - DMLC_TRACKER_PORT: Port number of the tracker.
|
* - DMLC_TRACKER_PORT: Port number of the tracker.
|
||||||
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
|
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||||
* - DMLC_ROLE: Role of the current task, "worker" or "server".
|
|
||||||
* - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
|
|
||||||
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
|
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
|
||||||
* Only applicable to the Federated communicator (use upper case for environment variables, use
|
* Only applicable to the Federated communicator (use upper case for environment variables, use
|
||||||
* lower case for runtime configuration):
|
* lower case for runtime configuration):
|
||||||
|
|||||||
@ -215,9 +215,9 @@ class SockAddrV4 {
|
|||||||
static SockAddrV4 Loopback();
|
static SockAddrV4 Loopback();
|
||||||
static SockAddrV4 InaddrAny();
|
static SockAddrV4 InaddrAny();
|
||||||
|
|
||||||
in_port_t Port() const { return ntohs(addr_.sin_port); }
|
[[nodiscard]] in_port_t Port() const { return ntohs(addr_.sin_port); }
|
||||||
|
|
||||||
std::string Addr() const {
|
[[nodiscard]] std::string Addr() const {
|
||||||
char buf[INET_ADDRSTRLEN];
|
char buf[INET_ADDRSTRLEN];
|
||||||
auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV4), &addr_.sin_addr,
|
auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV4), &addr_.sin_addr,
|
||||||
buf, INET_ADDRSTRLEN);
|
buf, INET_ADDRSTRLEN);
|
||||||
@ -226,7 +226,7 @@ class SockAddrV4 {
|
|||||||
}
|
}
|
||||||
return {buf};
|
return {buf};
|
||||||
}
|
}
|
||||||
sockaddr_in const &Handle() const { return addr_; }
|
[[nodiscard]] sockaddr_in const &Handle() const { return addr_; }
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -243,13 +243,13 @@ class SockAddress {
|
|||||||
explicit SockAddress(SockAddrV6 const &addr) : v6_{addr}, domain_{SockDomain::kV6} {}
|
explicit SockAddress(SockAddrV6 const &addr) : v6_{addr}, domain_{SockDomain::kV6} {}
|
||||||
explicit SockAddress(SockAddrV4 const &addr) : v4_{addr} {}
|
explicit SockAddress(SockAddrV4 const &addr) : v4_{addr} {}
|
||||||
|
|
||||||
auto Domain() const { return domain_; }
|
[[nodiscard]] auto Domain() const { return domain_; }
|
||||||
|
|
||||||
bool IsV4() const { return Domain() == SockDomain::kV4; }
|
[[nodiscard]] bool IsV4() const { return Domain() == SockDomain::kV4; }
|
||||||
bool IsV6() const { return !IsV4(); }
|
[[nodiscard]] bool IsV6() const { return !IsV4(); }
|
||||||
|
|
||||||
auto const &V4() const { return v4_; }
|
[[nodiscard]] auto const &V4() const { return v4_; }
|
||||||
auto const &V6() const { return v6_; }
|
[[nodiscard]] auto const &V6() const { return v6_; }
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -261,6 +261,7 @@ class TCPSocket {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
HandleT handle_{InvalidSocket()};
|
HandleT handle_{InvalidSocket()};
|
||||||
|
bool non_blocking_{false};
|
||||||
// There's no reliable way to extract the domain from a socket without first binding that
|
// There's no reliable way to extract the domain from a socket without first binding that
|
||||||
// socket on macos.
|
// socket on macos.
|
||||||
#if defined(__APPLE__)
|
#if defined(__APPLE__)
|
||||||
@ -276,7 +277,7 @@ class TCPSocket {
|
|||||||
/**
|
/**
|
||||||
* \brief Return the socket domain.
|
* \brief Return the socket domain.
|
||||||
*/
|
*/
|
||||||
auto Domain() const -> SockDomain {
|
[[nodiscard]] auto Domain() const -> SockDomain {
|
||||||
auto ret_iafamily = [](std::int32_t domain) {
|
auto ret_iafamily = [](std::int32_t domain) {
|
||||||
switch (domain) {
|
switch (domain) {
|
||||||
case AF_INET:
|
case AF_INET:
|
||||||
@ -321,10 +322,10 @@ class TCPSocket {
|
|||||||
#endif // platforms
|
#endif // platforms
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsClosed() const { return handle_ == InvalidSocket(); }
|
[[nodiscard]] bool IsClosed() const { return handle_ == InvalidSocket(); }
|
||||||
|
|
||||||
/** \brief get last error code if any */
|
/** @brief get last error code if any */
|
||||||
Result GetSockError() const {
|
[[nodiscard]] Result GetSockError() const {
|
||||||
std::int32_t optval = 0;
|
std::int32_t optval = 0;
|
||||||
socklen_t len = sizeof(optval);
|
socklen_t len = sizeof(optval);
|
||||||
auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&optval), &len);
|
auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&optval), &len);
|
||||||
@ -340,7 +341,7 @@ class TCPSocket {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** \brief check if anything bad happens */
|
/** \brief check if anything bad happens */
|
||||||
bool BadSocket() const {
|
[[nodiscard]] bool BadSocket() const {
|
||||||
if (IsClosed()) {
|
if (IsClosed()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -352,24 +353,56 @@ class TCPSocket {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetNonBlock(bool non_block) {
|
[[nodiscard]] Result NonBlocking(bool non_block) {
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
u_long mode = non_block ? 1 : 0;
|
u_long mode = non_block ? 1 : 0;
|
||||||
xgboost_CHECK_SYS_CALL(ioctlsocket(handle_, FIONBIO, &mode), NO_ERROR);
|
if (ioctlsocket(handle_, FIONBIO, &mode) != NO_ERROR) {
|
||||||
|
return system::FailWithCode("Failed to set socket to non-blocking.");
|
||||||
|
}
|
||||||
#else
|
#else
|
||||||
std::int32_t flag = fcntl(handle_, F_GETFL, 0);
|
std::int32_t flag = fcntl(handle_, F_GETFL, 0);
|
||||||
if (flag == -1) {
|
auto rc = flag;
|
||||||
system::ThrowAtError("fcntl");
|
if (rc == -1) {
|
||||||
|
return system::FailWithCode("Failed to get socket flag.");
|
||||||
}
|
}
|
||||||
if (non_block) {
|
if (non_block) {
|
||||||
flag |= O_NONBLOCK;
|
flag |= O_NONBLOCK;
|
||||||
} else {
|
} else {
|
||||||
flag &= ~O_NONBLOCK;
|
flag &= ~O_NONBLOCK;
|
||||||
}
|
}
|
||||||
if (fcntl(handle_, F_SETFL, flag) == -1) {
|
rc = fcntl(handle_, F_SETFL, flag);
|
||||||
system::ThrowAtError("fcntl");
|
if (rc == -1) {
|
||||||
|
return system::FailWithCode("Failed to set socket to non-blocking.");
|
||||||
}
|
}
|
||||||
#endif // _WIN32
|
#endif // _WIN32
|
||||||
|
non_blocking_ = non_block;
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
|
[[nodiscard]] bool NonBlocking() const { return non_blocking_; }
|
||||||
|
/**
 * @brief Set the receive timeout (SO_RCVTIMEO) on this socket.
 *
 * @param timeout Timeout in whole seconds.
 *
 * @return Success, or the failure produced by system::FailWithCode when setsockopt fails.
 */
[[nodiscard]] Result RecvTimeout(std::chrono::seconds timeout) {
#if defined(_WIN32)
  // Winsock expects SO_RCVTIMEO as a DWORD holding milliseconds, not a timeval.
  DWORD tv = static_cast<DWORD>(timeout.count() * 1000);
  auto rc = setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char const *>(&tv),
                       sizeof(tv));
#else
  timeval tv;
  tv.tv_sec = timeout.count();
  tv.tv_usec = 0;
  auto rc = setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char const *>(&tv),
                       sizeof(tv));
#endif  // defined(_WIN32)
  if (rc != 0) {
    return system::FailWithCode("Failed to set timeout on recv.");
  }
  return Success();
}
|
||||||
|
|
||||||
|
[[nodiscard]] Result SetBufSize(std::int32_t n_bytes) {
|
||||||
|
auto rc = setsockopt(this->Handle(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast<char *>(&n_bytes),
|
||||||
|
sizeof(n_bytes));
|
||||||
|
if (rc != 0) {
|
||||||
|
return system::FailWithCode("Failed to set send buffer size.");
|
||||||
|
}
|
||||||
|
rc = setsockopt(this->Handle(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast<char *>(&n_bytes),
|
||||||
|
sizeof(n_bytes));
|
||||||
|
if (rc != 0) {
|
||||||
|
return system::FailWithCode("Failed to set recv buffer size.");
|
||||||
|
}
|
||||||
|
return Success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetKeepAlive() {
|
void SetKeepAlive() {
|
||||||
@ -391,7 +424,7 @@ class TCPSocket {
|
|||||||
* \brief Accept new connection, returns a new TCP socket for the new connection.
|
* \brief Accept new connection, returns a new TCP socket for the new connection.
|
||||||
*/
|
*/
|
||||||
TCPSocket Accept() {
|
TCPSocket Accept() {
|
||||||
HandleT newfd = accept(handle_, nullptr, nullptr);
|
HandleT newfd = accept(Handle(), nullptr, nullptr);
|
||||||
if (newfd == InvalidSocket()) {
|
if (newfd == InvalidSocket()) {
|
||||||
system::ThrowAtError("accept");
|
system::ThrowAtError("accept");
|
||||||
}
|
}
|
||||||
@ -399,6 +432,18 @@ class TCPSocket {
|
|||||||
return newsock;
|
return newsock;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * @brief Accept a new connection and report the peer's IPv4 address.
 *
 * @param out  Receives the TCP socket wrapping the accepted connection.
 * @param addr Receives the address filled in by accept().
 *
 * @return Success, or the failure produced by system::FailWithCode when accept() fails.
 */
[[nodiscard]] Result Accept(TCPSocket *out, SockAddrV4 *addr) {
  struct sockaddr_in peer;
  socklen_t peer_len = sizeof(peer);
  HandleT accepted = accept(Handle(), reinterpret_cast<sockaddr *>(&peer), &peer_len);
  if (accepted != InvalidSocket()) {
    // Outputs are only written once the connection has been established.
    *addr = SockAddrV4{peer};
    *out = TCPSocket{accepted};
    return Success();
  }
  return system::FailWithCode("Failed to accept.");
}
|
||||||
|
|
||||||
~TCPSocket() {
|
~TCPSocket() {
|
||||||
if (!IsClosed()) {
|
if (!IsClosed()) {
|
||||||
Close();
|
Close();
|
||||||
@ -413,9 +458,9 @@ class TCPSocket {
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* \brief Return the native socket file descriptor.
|
* @brief Return the native socket file descriptor.
|
||||||
*/
|
*/
|
||||||
HandleT const &Handle() const { return handle_; }
|
[[nodiscard]] HandleT const &Handle() const { return handle_; }
|
||||||
/**
|
/**
|
||||||
* \brief Listen to incoming requests. Should be called after bind.
|
* \brief Listen to incoming requests. Should be called after bind.
|
||||||
*/
|
*/
|
||||||
@ -448,6 +493,49 @@ class TCPSocket {
|
|||||||
return ntohs(res_addr.sin_port);
|
return ntohs(res_addr.sin_port);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] auto Port() const {
|
||||||
|
if (this->Domain() == SockDomain::kV4) {
|
||||||
|
sockaddr_in res_addr;
|
||||||
|
socklen_t addrlen = sizeof(res_addr);
|
||||||
|
auto code = getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen);
|
||||||
|
if (code != 0) {
|
||||||
|
return std::make_pair(system::FailWithCode("getsockname"), std::int32_t{0});
|
||||||
|
}
|
||||||
|
return std::make_pair(Success(), std::int32_t{ntohs(res_addr.sin_port)});
|
||||||
|
} else {
|
||||||
|
sockaddr_in6 res_addr;
|
||||||
|
socklen_t addrlen = sizeof(res_addr);
|
||||||
|
auto code = getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen);
|
||||||
|
if (code != 0) {
|
||||||
|
return std::make_pair(system::FailWithCode("getsockname"), std::int32_t{0});
|
||||||
|
}
|
||||||
|
return std::make_pair(Success(), std::int32_t{ntohs(res_addr.sin6_port)});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * @brief Bind this socket to the given IP address and report the resulting port.
 *
 * @param ip   Address to bind to, either IPv4 or IPv6.
 * @param port Receives the port reported by Port() after a successful bind.
 *
 * @return Success, or a failure when bind() or the subsequent port query fails.
 */
[[nodiscard]] Result Bind(StringView ip, std::int32_t *port) {
  // NOTE(review): the second argument is presumably the port, with 0 asking the OS to pick
  // one -- confirm against MakeSockAddress.
  auto address = MakeSockAddress(ip, 0);

  std::int32_t status{0};
  if (address.IsV4()) {
    auto const &native = address.V4().Handle();
    status = bind(handle_, reinterpret_cast<sockaddr const *>(&native), sizeof(native));
  } else {
    auto const &native = address.V6().Handle();
    status = bind(handle_, reinterpret_cast<sockaddr const *>(&native), sizeof(native));
  }
  if (status != 0) {
    return system::FailWithCode("Failed to bind socket.");
  }

  auto [rc, bound_port] = this->Port();
  if (rc.OK()) {
    *port = bound_port;
    return Success();
  }
  return std::move(rc);
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Send data; if no error occurs, all of the data has been sent.
|
* \brief Send data; if no error occurs, all of the data has been sent.
|
||||||
*/
|
*/
|
||||||
@ -567,13 +655,9 @@ class TCPSocket {
|
|||||||
xgboost::collective::TCPSocket *out_conn);
|
xgboost::collective::TCPSocket *out_conn);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Get the local host name.
|
* @brief Get the local host name.
|
||||||
*/
|
*/
|
||||||
inline std::string GetHostName() {
|
[[nodiscard]] Result GetHostName(std::string *p_out);
|
||||||
char buf[HOST_NAME_MAX];
|
|
||||||
xgboost_CHECK_SYS_CALL(gethostname(&buf[0], HOST_NAME_MAX), 0);
|
|
||||||
return buf;
|
|
||||||
}
|
|
||||||
} // namespace collective
|
} // namespace collective
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
|
|||||||
@ -372,6 +372,19 @@ class Json {
|
|||||||
/*! \brief Use your own JsonWriter. */
|
/*! \brief Use your own JsonWriter. */
|
||||||
static void Dump(Json json, JsonWriter* writer);
|
static void Dump(Json json, JsonWriter* writer);
|
||||||
|
|
||||||
|
template <typename Container = std::string>
|
||||||
|
static Container Dump(Json json) {
|
||||||
|
if constexpr (std::is_same_v<Container, std::string>) {
|
||||||
|
std::string str;
|
||||||
|
Dump(json, &str);
|
||||||
|
return str;
|
||||||
|
} else {
|
||||||
|
std::vector<char> str;
|
||||||
|
Dump(json, &str);
|
||||||
|
return str;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Json() = default;
|
Json() = default;
|
||||||
|
|
||||||
// number
|
// number
|
||||||
|
|||||||
@ -29,7 +29,7 @@ struct StringView {
|
|||||||
public:
|
public:
|
||||||
constexpr StringView() = default;
|
constexpr StringView() = default;
|
||||||
constexpr StringView(CharT const* str, std::size_t size) : str_{str}, size_{size} {}
|
constexpr StringView(CharT const* str, std::size_t size) : str_{str}, size_{size} {}
|
||||||
explicit StringView(std::string const& str) : str_{str.c_str()}, size_{str.size()} {}
|
StringView(std::string const& str) : str_{str.c_str()}, size_{str.size()} {} // NOLINT
|
||||||
constexpr StringView(CharT const* str) // NOLINT
|
constexpr StringView(CharT const* str) // NOLINT
|
||||||
: str_{str}, size_{str == nullptr ? 0ul : Traits::length(str)} {}
|
: str_{str}, size_{str == nullptr ? 0ul : Traits::length(str)} {}
|
||||||
|
|
||||||
|
|||||||
@ -11,9 +11,7 @@
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::federated {
|
||||||
namespace federated {
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief A wrapper around the gRPC client.
|
* @brief A wrapper around the gRPC client.
|
||||||
*/
|
*/
|
||||||
@ -112,6 +110,4 @@ class FederatedClient {
|
|||||||
int const rank_;
|
int const rank_;
|
||||||
uint64_t sequence_number_{};
|
uint64_t sequence_number_{};
|
||||||
};
|
};
|
||||||
|
} // namespace xgboost::federated
|
||||||
} // namespace federated
|
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -9,9 +9,7 @@
|
|||||||
#include "../../src/common/io.h"
|
#include "../../src/common/io.h"
|
||||||
#include "federated_client.h"
|
#include "federated_client.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::collective {
|
||||||
namespace collective {
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief A Federated Learning communicator class that handles collective communication.
|
* @brief A Federated Learning communicator class that handles collective communication.
|
||||||
*/
|
*/
|
||||||
@ -118,13 +116,13 @@ class FederatedCommunicator : public Communicator {
|
|||||||
* \brief Get if the communicator is distributed.
|
* \brief Get if the communicator is distributed.
|
||||||
* \return True.
|
* \return True.
|
||||||
*/
|
*/
|
||||||
bool IsDistributed() const override { return true; }
|
[[nodiscard]] bool IsDistributed() const override { return true; }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Get if the communicator is federated.
|
* \brief Get if the communicator is federated.
|
||||||
* \return True.
|
* \return True.
|
||||||
*/
|
*/
|
||||||
bool IsFederated() const override { return true; }
|
[[nodiscard]] bool IsFederated() const override { return true; }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Perform in-place allgather.
|
* \brief Perform in-place allgather.
|
||||||
@ -189,5 +187,4 @@ class FederatedCommunicator : public Communicator {
|
|||||||
private:
|
private:
|
||||||
std::unique_ptr<xgboost::federated::FederatedClient> client_{};
|
std::unique_ptr<xgboost::federated::FederatedClient> client_{};
|
||||||
};
|
};
|
||||||
} // namespace collective
|
} // namespace xgboost::collective
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -11,9 +11,7 @@
|
|||||||
|
|
||||||
#include "../../src/common/io.h"
|
#include "../../src/common/io.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::federated {
|
||||||
namespace federated {
|
|
||||||
|
|
||||||
grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
|
grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
|
||||||
AllgatherReply* reply) {
|
AllgatherReply* reply) {
|
||||||
handler_.Allgather(request->send_buffer().data(), request->send_buffer().size(),
|
handler_.Allgather(request->send_buffer().data(), request->send_buffer().size(),
|
||||||
@ -75,6 +73,4 @@ void RunInsecureServer(int port, int world_size) {
|
|||||||
|
|
||||||
server->Wait();
|
server->Wait();
|
||||||
}
|
}
|
||||||
|
} // namespace xgboost::federated
|
||||||
} // namespace federated
|
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -115,9 +115,12 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
|
|||||||
// start socket
|
// start socket
|
||||||
xgboost::system::SocketStartup();
|
xgboost::system::SocketStartup();
|
||||||
utils::Assert(all_links.size() == 0, "can only call Init once");
|
utils::Assert(all_links.size() == 0, "can only call Init once");
|
||||||
this->host_uri = xgboost::collective::GetHostName();
|
auto rc = xgboost::collective::GetHostName(&this->host_uri);
|
||||||
|
if (!rc.OK()) {
|
||||||
|
LOG(FATAL) << rc.Report();
|
||||||
|
}
|
||||||
// get information from tracker
|
// get information from tracker
|
||||||
auto rc = this->ReConnectLinks();
|
rc = this->ReConnectLinks();
|
||||||
if (rc.OK()) {
|
if (rc.OK()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -406,13 +409,14 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
if (!match) all_links.emplace_back(std::move(r));
|
if (!match) all_links.emplace_back(std::move(r));
|
||||||
}
|
}
|
||||||
sock_listen.Close();
|
sock_listen.Close();
|
||||||
|
|
||||||
this->parent_index = -1;
|
this->parent_index = -1;
|
||||||
// setup tree links and ring structure
|
// setup tree links and ring structure
|
||||||
tree_links.plinks.clear();
|
tree_links.plinks.clear();
|
||||||
for (auto &all_link : all_links) {
|
for (auto &all_link : all_links) {
|
||||||
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
|
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
|
||||||
// set the socket to non-blocking mode, enable TCP keepalive
|
// set the socket to non-blocking mode, enable TCP keepalive
|
||||||
all_link.sock.SetNonBlock(true);
|
CHECK(all_link.sock.NonBlocking(true).OK());
|
||||||
all_link.sock.SetKeepAlive();
|
all_link.sock.SetKeepAlive();
|
||||||
if (rabit_enable_tcp_no_delay) {
|
if (rabit_enable_tcp_no_delay) {
|
||||||
all_link.sock.SetNoDelay();
|
all_link.sock.SetNoDelay();
|
||||||
|
|||||||
@ -11,9 +11,7 @@
|
|||||||
#include "../../plugin/federated/federated_communicator.h"
|
#include "../../plugin/federated/federated_communicator.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::collective {
|
||||||
namespace collective {
|
|
||||||
|
|
||||||
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
|
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
|
||||||
thread_local CommunicatorType Communicator::type_{};
|
thread_local CommunicatorType Communicator::type_{};
|
||||||
|
|
||||||
@ -57,6 +55,4 @@ void Communicator::Finalize() {
|
|||||||
communicator_.reset(new NoOpCommunicator());
|
communicator_.reset(new NoOpCommunicator());
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
} // namespace xgboost::collective
|
||||||
} // namespace collective
|
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
*/
|
*/
|
||||||
#include "xgboost/collective/socket.h"
|
#include "xgboost/collective/socket.h"
|
||||||
|
|
||||||
|
#include <array> // for array
|
||||||
#include <cstddef> // std::size_t
|
#include <cstddef> // std::size_t
|
||||||
#include <cstdint> // std::int32_t
|
#include <cstdint> // std::int32_t
|
||||||
#include <cstring> // std::memcpy, std::memset
|
#include <cstring> // std::memcpy, std::memset
|
||||||
@ -92,13 +93,18 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
|
|||||||
|
|
||||||
conn = TCPSocket::Create(addr.Domain());
|
conn = TCPSocket::Create(addr.Domain());
|
||||||
CHECK_EQ(static_cast<std::int32_t>(conn.Domain()), static_cast<std::int32_t>(addr.Domain()));
|
CHECK_EQ(static_cast<std::int32_t>(conn.Domain()), static_cast<std::int32_t>(addr.Domain()));
|
||||||
conn.SetNonBlock(true);
|
auto non_blocking = conn.NonBlocking();
|
||||||
|
auto rc = conn.NonBlocking(true);
|
||||||
|
if (!rc.OK()) {
|
||||||
|
return Fail("Failed to set socket option.", std::move(rc));
|
||||||
|
}
|
||||||
|
|
||||||
Result last_error;
|
Result last_error;
|
||||||
auto log_failure = [&host, &last_error](Result err, char const *file, std::int32_t line) {
|
auto log_failure = [&host, &last_error, port](Result err, char const *file, std::int32_t line) {
|
||||||
last_error = std::move(err);
|
last_error = std::move(err);
|
||||||
LOG(WARNING) << std::filesystem::path{file}.filename().string() << "(" << line
|
LOG(WARNING) << std::filesystem::path{file}.filename().string() << "(" << line
|
||||||
<< "): Failed to connect to:" << host << " Error:" << last_error.Report();
|
<< "): Failed to connect to:" << host << ":" << port
|
||||||
|
<< " Error:" << last_error.Report();
|
||||||
};
|
};
|
||||||
|
|
||||||
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
|
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
|
||||||
@ -138,12 +144,9 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.SetNonBlock(false);
|
return conn.NonBlocking(non_blocking);
|
||||||
return Success();
|
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
conn.SetNonBlock(false);
|
return conn.NonBlocking(non_blocking);
|
||||||
return Success();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -152,4 +155,13 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
|
|||||||
conn.Close();
|
conn.Close();
|
||||||
return Fail(ss.str(), std::move(last_error));
|
return Fail(ss.str(), std::move(last_error));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * @brief Get the local host name.
 *
 * @param p_out Receives the host name on success; left untouched on failure.
 *
 * @return Success, or the failure produced by system::FailWithCode when gethostname() fails.
 */
[[nodiscard]] Result GetHostName(std::string *p_out) {
  // HOST_NAME_MAX does not count the terminating null byte, and POSIX leaves termination
  // unspecified when the name is truncated. Reserve one extra, zero-filled byte so the
  // buffer is always a valid C string.
  std::array<char, HOST_NAME_MAX + 1> buf{};
  if (gethostname(buf.data(), HOST_NAME_MAX + 1) != 0) {
    return system::FailWithCode("Failed to get host name.");
  }
  buf.back() = '\0';
  *p_out = buf.data();
  return Success();
}
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
@ -73,4 +73,15 @@ TEST(Socket, Basic) {
|
|||||||
|
|
||||||
system::SocketFinalize();
|
system::SocketFinalize();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Binding an IPv4 socket to the wildcard address must succeed and report a non-zero port.
TEST(Socket, Bind) {
  system::SocketStartup();

  auto wildcard = SockAddrV4::InaddrAny().Addr();
  auto tcp = TCPSocket::Create(SockDomain::kV4);

  std::int32_t bound_port{0};
  auto status = tcp.Bind(wildcard, &bound_port);
  ASSERT_TRUE(status.OK());
  ASSERT_NE(bound_port, 0);

  system::SocketFinalize();
}
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <iterator> // for back_inserter
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
#include "../../../src/common/charconv.h"
|
#include "../../../src/common/charconv.h"
|
||||||
@ -691,4 +692,16 @@ TEST(Json, TypeCheck) {
|
|||||||
ASSERT_NE(err.find("foo"), std::string::npos);
|
ASSERT_NE(err.find("foo"), std::string::npos);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The std::string and std::vector<char> flavours of Json::Dump must emit identical bytes.
TEST(Json, Dump) {
  auto model_str = GetModelStr();
  auto model = Json::Load(model_str);

  std::string as_string = Json::Dump(model);
  std::vector<char> as_chars = Json::Dump<std::vector<char>>(model);

  ASSERT_EQ(as_string.size(), as_chars.size());
  std::size_t idx = 0;
  while (idx < as_string.size()) {
    ASSERT_EQ(as_string[idx], as_chars[idx]);
    ++idx;
  }
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user