Initial support for IPv6 (#8225)
- Merge rabit socket into XGBoost. - Dask interface support. - Add test to the socket.
This commit is contained in:
parent
7d43e74e71
commit
b791446623
@ -91,6 +91,9 @@
|
|||||||
#include "../src/common/timer.cc"
|
#include "../src/common/timer.cc"
|
||||||
#include "../src/common/version.cc"
|
#include "../src/common/version.cc"
|
||||||
|
|
||||||
|
// collective
|
||||||
|
#include "../src/collective/socket.cc"
|
||||||
|
|
||||||
// c_api
|
// c_api
|
||||||
#include "../src/c_api/c_api.cc"
|
#include "../src/c_api/c_api.cc"
|
||||||
#include "../src/c_api/c_api_error.cc"
|
#include "../src/c_api/c_api_error.cc"
|
||||||
|
|||||||
@ -204,7 +204,7 @@ latex_documents = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
intersphinx_mapping = {
|
intersphinx_mapping = {
|
||||||
"python": ("https://docs.python.org/3.6", None),
|
"python": ("https://docs.python.org/3.8", None),
|
||||||
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
|
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
|
||||||
"scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
|
"scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
|
||||||
"pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
|
"pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
|
||||||
|
|||||||
@ -474,7 +474,6 @@ interface, including callback functions, custom evaluation metric and objective:
|
|||||||
callbacks=[early_stop],
|
callbacks=[early_stop],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
.. _tracker-ip:
|
.. _tracker-ip:
|
||||||
|
|
||||||
***************
|
***************
|
||||||
@ -504,6 +503,35 @@ dask config is used:
|
|||||||
reg = dxgb.DaskXGBRegressor()
|
reg = dxgb.DaskXGBRegressor()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
************
|
||||||
|
IPv6 Support
|
||||||
|
************
|
||||||
|
|
||||||
|
.. versionadded:: 2.0.0
|
||||||
|
|
||||||
|
XGBoost has initial IPv6 support for the dask interface on Linux. Due to most of the
|
||||||
|
cluster support for IPv6 is partial (dual stack instead of IPv6 only), we require
|
||||||
|
additional user configuration similar to :ref:`tracker-ip` to help XGBoost obtain the
|
||||||
|
correct address information:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import dask
|
||||||
|
from distributed import Client
|
||||||
|
from xgboost import dask as dxgb
|
||||||
|
# let xgboost know the scheduler address, use the same bracket format as dask.
|
||||||
|
with dask.config.set({"xgboost.scheduler_address": "[fd20:b6f:f759:9800::]"}):
|
||||||
|
with Client("[fd20:b6f:f759:9800::]") as client:
|
||||||
|
reg = dxgb.DaskXGBRegressor(tree_method="hist")
|
||||||
|
|
||||||
|
|
||||||
|
When GPU is used, XGBoost employs `NCCL <https://developer.nvidia.com/nccl>`_ as the
|
||||||
|
underlying communication framework, which may require some additional configuration via
|
||||||
|
environment variable depending on the setting of the cluster. Please note that IPv6
|
||||||
|
support is Unix only.
|
||||||
|
|
||||||
|
|
||||||
*****************************************************************************
|
*****************************************************************************
|
||||||
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
|
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
|
||||||
*****************************************************************************
|
*****************************************************************************
|
||||||
|
|||||||
536
include/xgboost/collective/socket.h
Normal file
536
include/xgboost/collective/socket.h
Normal file
@ -0,0 +1,536 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright (c) 2022 by XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#if !defined(NOMINMAX) && defined(_WIN32)
|
||||||
|
#define NOMINMAX
|
||||||
|
#endif // !defined(NOMINMAX)
|
||||||
|
|
||||||
|
#include <cerrno> // errno, EINTR, EBADF
|
||||||
|
#include <climits> // HOST_NAME_MAX
|
||||||
|
#include <cstddef> // std::size_t
|
||||||
|
#include <cstdint> // std::int32_t, std::uint16_t
|
||||||
|
#include <cstring> // memset
|
||||||
|
#include <limits> // std::numeric_limits
|
||||||
|
#include <string> // std::string
|
||||||
|
#include <system_error> // std::error_code, std::system_category
|
||||||
|
#include <utility> // std::swap
|
||||||
|
|
||||||
|
#if !defined(xgboost_IS_MINGW)
|
||||||
|
#define xgboost_IS_MINGW() defined(__MINGW32__)
|
||||||
|
#endif // xgboost_IS_MINGW
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
|
||||||
|
#include <winsock2.h>
|
||||||
|
#include <ws2tcpip.h>
|
||||||
|
|
||||||
|
using in_port_t = std::uint16_t;
|
||||||
|
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
#pragma comment(lib, "Ws2_32.lib")
|
||||||
|
#endif // _MSC_VER
|
||||||
|
|
||||||
|
#if !xgboost_IS_MINGW()
|
||||||
|
using ssize_t = int;
|
||||||
|
#endif // !xgboost_IS_MINGW()
|
||||||
|
|
||||||
|
#else // UNIX
|
||||||
|
|
||||||
|
#include <arpa/inet.h> // inet_ntop
|
||||||
|
#include <fcntl.h> // fcntl, F_GETFL, O_NONBLOCK
|
||||||
|
#include <netinet/in.h> // sockaddr_in6, sockaddr_in, in_port_t, INET6_ADDRSTRLEN, INET_ADDRSTRLEN
|
||||||
|
#include <netinet/in.h> // IPPROTO_TCP
|
||||||
|
#include <netinet/tcp.h> // TCP_NODELAY
|
||||||
|
#include <sys/socket.h> // socket, SOL_SOCKET, SO_ERROR, MSG_WAITALL, recv, send, AF_INET6, AF_INET
|
||||||
|
#include <unistd.h> // close
|
||||||
|
|
||||||
|
#if defined(__sun) || defined(sun)
|
||||||
|
#include <sys/sockio.h>
|
||||||
|
#endif // defined(__sun) || defined(sun)
|
||||||
|
|
||||||
|
#endif // defined(_WIN32)
|
||||||
|
|
||||||
|
#include "xgboost/base.h" // XGBOOST_EXPECT
|
||||||
|
#include "xgboost/logging.h" // LOG
|
||||||
|
#include "xgboost/string_view.h" // StringView
|
||||||
|
|
||||||
|
#if !defined(HOST_NAME_MAX)
|
||||||
|
#define HOST_NAME_MAX 256 // macos
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
|
||||||
|
#if xgboost_IS_MINGW()
|
||||||
|
// see the dummy implementation of `poll` in rabit for more info.
|
||||||
|
inline void MingWError() { LOG(FATAL) << "Distributed training on mingw is not supported."; }
|
||||||
|
#endif // xgboost_IS_MINGW()
|
||||||
|
|
||||||
|
namespace system {
|
||||||
|
inline std::int32_t LastError() {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
return WSAGetLastError();
|
||||||
|
#else
|
||||||
|
int errsv = errno;
|
||||||
|
return errsv;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
#if defined(__GLIBC__)
|
||||||
|
inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
|
||||||
|
std::int32_t line = __builtin_LINE(),
|
||||||
|
char const *file = __builtin_FILE()) {
|
||||||
|
auto err = std::error_code{errsv, std::system_category()};
|
||||||
|
LOG(FATAL) << "\n"
|
||||||
|
<< file << "(" << line << "): Failed to call `" << fn_name << "`: " << err.message()
|
||||||
|
<< std::endl;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError()) {
|
||||||
|
auto err = std::error_code{errsv, std::system_category()};
|
||||||
|
LOG(FATAL) << "Failed to call `" << fn_name << "`: " << err.message() << std::endl;
|
||||||
|
}
|
||||||
|
#endif // defined(__GLIBC__)
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
using SocketT = SOCKET;
|
||||||
|
#else
|
||||||
|
using SocketT = int;
|
||||||
|
#endif // defined(_WIN32)
|
||||||
|
|
||||||
|
#if !defined(xgboost_CHECK_SYS_CALL)
|
||||||
|
#define xgboost_CHECK_SYS_CALL(exp, expected) \
|
||||||
|
do { \
|
||||||
|
if (XGBOOST_EXPECT((exp) != (expected), false)) { \
|
||||||
|
::xgboost::system::ThrowAtError(#exp); \
|
||||||
|
} \
|
||||||
|
} while (false)
|
||||||
|
#endif // !defined(xgboost_CHECK_SYS_CALL)
|
||||||
|
|
||||||
|
inline std::int32_t CloseSocket(SocketT fd) {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
return closesocket(fd);
|
||||||
|
#else
|
||||||
|
return close(fd);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool LastErrorWouldBlock() {
|
||||||
|
int errsv = LastError();
|
||||||
|
#ifdef _WIN32
|
||||||
|
return errsv == WSAEWOULDBLOCK;
|
||||||
|
#else
|
||||||
|
return errsv == EAGAIN || errsv == EWOULDBLOCK;
|
||||||
|
#endif // _WIN32
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void SocketStartup() {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
WSADATA wsa_data;
|
||||||
|
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
|
||||||
|
ThrowAtError("WSAStartup");
|
||||||
|
}
|
||||||
|
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
|
||||||
|
WSACleanup();
|
||||||
|
LOG(FATAL) << "Could not find a usable version of Winsock.dll";
|
||||||
|
}
|
||||||
|
#endif // defined(_WIN32)
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void SocketFinalize() {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
WSACleanup();
|
||||||
|
#endif // defined(_WIN32)
|
||||||
|
}
|
||||||
|
|
||||||
|
#if defined(_WIN32) && xgboost_IS_MINGW()
|
||||||
|
// dummy definition for old mysys32.
|
||||||
|
inline const char *inet_ntop(int, const void *, char *, socklen_t) { // NOLINT
|
||||||
|
MingWError();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
using ::inet_ntop;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace system
|
||||||
|
|
||||||
|
namespace collective {
|
||||||
|
class SockAddress;
|
||||||
|
|
||||||
|
enum class SockDomain : std::int32_t { kV4 = AF_INET, kV6 = AF_INET6 };
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Parse host address and return a SockAddress instance. Supports IPv4 and IPv6
|
||||||
|
* host.
|
||||||
|
*/
|
||||||
|
SockAddress MakeSockAddress(StringView host, in_port_t port);
|
||||||
|
|
||||||
|
class SockAddrV6 {
|
||||||
|
sockaddr_in6 addr_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit SockAddrV6(sockaddr_in6 addr) : addr_{addr} {}
|
||||||
|
SockAddrV6() { std::memset(&addr_, '\0', sizeof(addr_)); }
|
||||||
|
|
||||||
|
static SockAddrV6 Loopback();
|
||||||
|
static SockAddrV6 InaddrAny();
|
||||||
|
|
||||||
|
in_port_t Port() const { return ntohs(addr_.sin6_port); }
|
||||||
|
|
||||||
|
std::string Addr() const {
|
||||||
|
char buf[INET6_ADDRSTRLEN];
|
||||||
|
auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV6), &addr_.sin6_addr,
|
||||||
|
buf, INET6_ADDRSTRLEN);
|
||||||
|
if (s == nullptr) {
|
||||||
|
system::ThrowAtError("inet_ntop");
|
||||||
|
}
|
||||||
|
return {buf};
|
||||||
|
}
|
||||||
|
sockaddr_in6 const &Handle() const { return addr_; }
|
||||||
|
};
|
||||||
|
|
||||||
|
class SockAddrV4 {
|
||||||
|
private:
|
||||||
|
sockaddr_in addr_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit SockAddrV4(sockaddr_in addr) : addr_{addr} {}
|
||||||
|
SockAddrV4() { std::memset(&addr_, '\0', sizeof(addr_)); }
|
||||||
|
|
||||||
|
static SockAddrV4 Loopback();
|
||||||
|
static SockAddrV4 InaddrAny();
|
||||||
|
|
||||||
|
in_port_t Port() const { return ntohs(addr_.sin_port); }
|
||||||
|
|
||||||
|
std::string Addr() const {
|
||||||
|
char buf[INET_ADDRSTRLEN];
|
||||||
|
auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV4), &addr_.sin_addr,
|
||||||
|
buf, INET_ADDRSTRLEN);
|
||||||
|
if (s == nullptr) {
|
||||||
|
system::ThrowAtError("inet_ntop");
|
||||||
|
}
|
||||||
|
return {buf};
|
||||||
|
}
|
||||||
|
sockaddr_in const &Handle() const { return addr_; }
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Address for TCP socket, can be either IPv4 or IPv6.
|
||||||
|
*/
|
||||||
|
class SockAddress {
|
||||||
|
private:
|
||||||
|
SockAddrV6 v6_;
|
||||||
|
SockAddrV4 v4_;
|
||||||
|
SockDomain domain_{SockDomain::kV4};
|
||||||
|
|
||||||
|
public:
|
||||||
|
SockAddress() = default;
|
||||||
|
explicit SockAddress(SockAddrV6 const &addr) : v6_{addr}, domain_{SockDomain::kV6} {}
|
||||||
|
explicit SockAddress(SockAddrV4 const &addr) : v4_{addr} {}
|
||||||
|
|
||||||
|
auto Domain() const { return domain_; }
|
||||||
|
|
||||||
|
bool IsV4() const { return Domain() == SockDomain::kV4; }
|
||||||
|
bool IsV6() const { return !IsV4(); }
|
||||||
|
|
||||||
|
auto const &V4() const { return v4_; }
|
||||||
|
auto const &V6() const { return v6_; }
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief TCP socket for simple communication.
|
||||||
|
*/
|
||||||
|
class TCPSocket {
|
||||||
|
public:
|
||||||
|
using HandleT = system::SocketT;
|
||||||
|
|
||||||
|
private:
|
||||||
|
HandleT handle_{InvalidSocket()};
|
||||||
|
// There's reliable no way to extract domain from a socket without first binding that
|
||||||
|
// socket on macos.
|
||||||
|
#if defined(__APPLE__)
|
||||||
|
SockDomain domain_{SockDomain::kV4};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
constexpr static HandleT InvalidSocket() { return -1; }
|
||||||
|
|
||||||
|
explicit TCPSocket(HandleT newfd) : handle_{newfd} {}
|
||||||
|
|
||||||
|
public:
|
||||||
|
TCPSocket() = default;
|
||||||
|
/**
|
||||||
|
* \brief Return the socket domain.
|
||||||
|
*/
|
||||||
|
auto Domain() const -> SockDomain {
|
||||||
|
auto ret_iafamily = [](std::int32_t domain) {
|
||||||
|
switch (domain) {
|
||||||
|
case AF_INET:
|
||||||
|
return SockDomain::kV4;
|
||||||
|
case AF_INET6:
|
||||||
|
return SockDomain::kV6;
|
||||||
|
default: {
|
||||||
|
LOG(FATAL) << "Unknown IA family.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return SockDomain::kV4;
|
||||||
|
};
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
WSAPROTOCOL_INFOA info;
|
||||||
|
socklen_t len = sizeof(info);
|
||||||
|
xgboost_CHECK_SYS_CALL(
|
||||||
|
getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO, reinterpret_cast<char *>(&info), &len),
|
||||||
|
0);
|
||||||
|
return ret_iafamily(info.iAddressFamily);
|
||||||
|
#elif defined(__APPLE__)
|
||||||
|
return domain_;
|
||||||
|
#elif defined(__unix__)
|
||||||
|
std::int32_t domain;
|
||||||
|
socklen_t len = sizeof(domain);
|
||||||
|
xgboost_CHECK_SYS_CALL(
|
||||||
|
getsockopt(handle_, SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len), 0);
|
||||||
|
return ret_iafamily(domain);
|
||||||
|
#else
|
||||||
|
LOG(FATAL) << "Unknown platform.";
|
||||||
|
return ret_iafamily(AF_INET);
|
||||||
|
#endif // platforms
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsClosed() const { return handle_ == InvalidSocket(); }
|
||||||
|
|
||||||
|
/** \brief get last error code if any */
|
||||||
|
std::int32_t GetSockError() const {
|
||||||
|
std::int32_t error = 0;
|
||||||
|
socklen_t len = sizeof(error);
|
||||||
|
xgboost_CHECK_SYS_CALL(
|
||||||
|
getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&error), &len), 0);
|
||||||
|
return error;
|
||||||
|
}
|
||||||
|
/** \brief check if anything bad happens */
|
||||||
|
bool BadSocket() const {
|
||||||
|
if (IsClosed()) return true;
|
||||||
|
std::int32_t err = GetSockError();
|
||||||
|
if (err == EBADF || err == EINTR) return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetNonBlock() {
|
||||||
|
bool non_block{true};
|
||||||
|
#if defined(_WIN32)
|
||||||
|
u_long mode = non_block ? 1 : 0;
|
||||||
|
xgboost_CHECK_SYS_CALL(ioctlsocket(handle_, FIONBIO, &mode), NO_ERROR);
|
||||||
|
#else
|
||||||
|
std::int32_t flag = fcntl(handle_, F_GETFL, 0);
|
||||||
|
if (flag == -1) {
|
||||||
|
system::ThrowAtError("fcntl");
|
||||||
|
}
|
||||||
|
if (non_block) {
|
||||||
|
flag |= O_NONBLOCK;
|
||||||
|
} else {
|
||||||
|
flag &= ~O_NONBLOCK;
|
||||||
|
}
|
||||||
|
if (fcntl(handle_, F_SETFL, flag) == -1) {
|
||||||
|
system::ThrowAtError("fcntl");
|
||||||
|
}
|
||||||
|
#endif // _WIN32
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetKeepAlive() {
|
||||||
|
std::int32_t keepalive = 1;
|
||||||
|
xgboost_CHECK_SYS_CALL(setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE,
|
||||||
|
reinterpret_cast<char *>(&keepalive), sizeof(keepalive)),
|
||||||
|
0);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetNoDelay() {
|
||||||
|
std::int32_t tcp_no_delay = 1;
|
||||||
|
xgboost_CHECK_SYS_CALL(
|
||||||
|
setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
|
||||||
|
sizeof(tcp_no_delay)),
|
||||||
|
0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Accept new connection, returns a new TCP socket for the new connection.
|
||||||
|
*/
|
||||||
|
TCPSocket Accept() {
|
||||||
|
HandleT newfd = accept(handle_, nullptr, nullptr);
|
||||||
|
if (newfd == InvalidSocket()) {
|
||||||
|
system::ThrowAtError("accept");
|
||||||
|
}
|
||||||
|
TCPSocket newsock{newfd};
|
||||||
|
return newsock;
|
||||||
|
}
|
||||||
|
|
||||||
|
~TCPSocket() {
|
||||||
|
if (!IsClosed()) {
|
||||||
|
Close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket(TCPSocket const &that) = delete;
|
||||||
|
TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
|
||||||
|
TCPSocket &operator=(TCPSocket const &that) = delete;
|
||||||
|
TCPSocket &operator=(TCPSocket &&that) {
|
||||||
|
std::swap(this->handle_, that.handle_);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* \brief Return the native socket file descriptor.
|
||||||
|
*/
|
||||||
|
HandleT const &Handle() const { return handle_; }
|
||||||
|
/**
|
||||||
|
* \brief Listen to incoming requests. Should be called after bind.
|
||||||
|
*/
|
||||||
|
void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); }
|
||||||
|
/**
|
||||||
|
* \brief Bind socket to INADDR_ANY, return the port selected by the OS.
|
||||||
|
*/
|
||||||
|
in_port_t BindHost() {
|
||||||
|
if (Domain() == SockDomain::kV6) {
|
||||||
|
auto addr = SockAddrV6::InaddrAny();
|
||||||
|
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
|
||||||
|
xgboost_CHECK_SYS_CALL(
|
||||||
|
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
|
||||||
|
|
||||||
|
sockaddr_in6 res_addr;
|
||||||
|
socklen_t addrlen = sizeof(res_addr);
|
||||||
|
xgboost_CHECK_SYS_CALL(
|
||||||
|
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
|
||||||
|
return ntohs(res_addr.sin6_port);
|
||||||
|
} else {
|
||||||
|
auto addr = SockAddrV4::InaddrAny();
|
||||||
|
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
|
||||||
|
xgboost_CHECK_SYS_CALL(
|
||||||
|
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
|
||||||
|
|
||||||
|
sockaddr_in res_addr;
|
||||||
|
socklen_t addrlen = sizeof(res_addr);
|
||||||
|
xgboost_CHECK_SYS_CALL(
|
||||||
|
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
|
||||||
|
return ntohs(res_addr.sin_port);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* \brief Send data, without error then all data should be sent.
|
||||||
|
*/
|
||||||
|
auto SendAll(void const *buf, std::size_t len) {
|
||||||
|
char const *_buf = reinterpret_cast<const char *>(buf);
|
||||||
|
std::size_t ndone = 0;
|
||||||
|
while (ndone < len) {
|
||||||
|
ssize_t ret = send(handle_, _buf, len - ndone, 0);
|
||||||
|
if (ret == -1) {
|
||||||
|
if (system::LastErrorWouldBlock()) {
|
||||||
|
return ndone;
|
||||||
|
}
|
||||||
|
system::ThrowAtError("send");
|
||||||
|
}
|
||||||
|
_buf += ret;
|
||||||
|
ndone += ret;
|
||||||
|
}
|
||||||
|
return ndone;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* \brief Receive data, without error then all data should be received.
|
||||||
|
*/
|
||||||
|
auto RecvAll(void *buf, std::size_t len) {
|
||||||
|
char *_buf = reinterpret_cast<char *>(buf);
|
||||||
|
std::size_t ndone = 0;
|
||||||
|
while (ndone < len) {
|
||||||
|
ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
|
||||||
|
if (ret == -1) {
|
||||||
|
if (system::LastErrorWouldBlock()) {
|
||||||
|
return ndone;
|
||||||
|
}
|
||||||
|
system::ThrowAtError("recv");
|
||||||
|
}
|
||||||
|
if (ret == 0) {
|
||||||
|
return ndone;
|
||||||
|
}
|
||||||
|
_buf += ret;
|
||||||
|
ndone += ret;
|
||||||
|
}
|
||||||
|
return ndone;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* \brief Send data using the socket
|
||||||
|
* \param buf the pointer to the buffer
|
||||||
|
* \param len the size of the buffer
|
||||||
|
* \param flags extra flags
|
||||||
|
* \return size of data actually sent return -1 if error occurs
|
||||||
|
*/
|
||||||
|
auto Send(const void *buf_, std::size_t len, std::int32_t flags = 0) {
|
||||||
|
const char *buf = reinterpret_cast<const char *>(buf_);
|
||||||
|
return send(handle_, buf, len, flags);
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* \brief receive data using the socket
|
||||||
|
* \param buf the pointer to the buffer
|
||||||
|
* \param len the size of the buffer
|
||||||
|
* \param flags extra flags
|
||||||
|
* \return size of data actually received return -1 if error occurs
|
||||||
|
*/
|
||||||
|
auto Recv(void *buf, std::size_t len, std::int32_t flags = 0) {
|
||||||
|
char *_buf = reinterpret_cast<char *>(buf);
|
||||||
|
return recv(handle_, _buf, len, flags);
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* \brief Send string, format is matched with the Python socket wrapper in RABIT.
|
||||||
|
*/
|
||||||
|
std::size_t Send(StringView str);
|
||||||
|
/**
|
||||||
|
* \brief Receive string, format is matched with the Python socket wrapper in RABIT.
|
||||||
|
*/
|
||||||
|
std::size_t Recv(std::string *p_str);
|
||||||
|
/**
|
||||||
|
* \brief Close the socket, called automatically in destructor if the socket is not closed.
|
||||||
|
*/
|
||||||
|
void Close() {
|
||||||
|
if (InvalidSocket() != handle_) {
|
||||||
|
xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0);
|
||||||
|
handle_ = InvalidSocket();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* \brief Create a TCP socket on specified domain.
|
||||||
|
*/
|
||||||
|
static TCPSocket Create(SockDomain domain) {
|
||||||
|
#if xgboost_IS_MINGW()
|
||||||
|
MingWError();
|
||||||
|
return {};
|
||||||
|
#else
|
||||||
|
auto fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
|
||||||
|
if (fd == InvalidSocket()) {
|
||||||
|
system::ThrowAtError("socket");
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket socket{fd};
|
||||||
|
#if defined(__APPLE__)
|
||||||
|
socket.domain_ = domain;
|
||||||
|
#endif // defined(__APPLE__)
|
||||||
|
return socket;
|
||||||
|
#endif // xgboost_IS_MINGW()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Connect to remote address, returns the error code if failed (no exception is
|
||||||
|
* raised so that we can retry).
|
||||||
|
*/
|
||||||
|
std::error_code Connect(SockAddress const &addr, TCPSocket *out);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Get the local host name.
|
||||||
|
*/
|
||||||
|
inline std::string GetHostName() {
|
||||||
|
char buf[HOST_NAME_MAX];
|
||||||
|
xgboost_CHECK_SYS_CALL(gethostname(&buf[0], HOST_NAME_MAX), 0);
|
||||||
|
return buf;
|
||||||
|
}
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
|
|
||||||
|
#undef xgboost_CHECK_SYS_CALL
|
||||||
|
#undef xgboost_IS_MINGW
|
||||||
@ -52,6 +52,7 @@ from typing import (
|
|||||||
Sequence,
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
TypedDict,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
@ -102,19 +103,13 @@ else:
|
|||||||
|
|
||||||
_DaskCollection = Union["da.Array", "dd.DataFrame", "dd.Series"]
|
_DaskCollection = Union["da.Array", "dd.DataFrame", "dd.Series"]
|
||||||
_DataT = Union["da.Array", "dd.DataFrame"] # do not use series as predictor
|
_DataT = Union["da.Array", "dd.DataFrame"] # do not use series as predictor
|
||||||
|
TrainReturnT = TypedDict(
|
||||||
try:
|
"TrainReturnT",
|
||||||
from mypy_extensions import TypedDict
|
{
|
||||||
|
"booster": Booster,
|
||||||
TrainReturnT = TypedDict(
|
"history": Dict,
|
||||||
"TrainReturnT",
|
},
|
||||||
{
|
)
|
||||||
"booster": Booster,
|
|
||||||
"history": Dict,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
TrainReturnT = Dict[str, Any] # type:ignore
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"RabitContext",
|
"RabitContext",
|
||||||
@ -832,11 +827,15 @@ async def _get_rabit_args(
|
|||||||
if k not in valid_config:
|
if k not in valid_config:
|
||||||
raise ValueError(f"Unknown configuration: {k}")
|
raise ValueError(f"Unknown configuration: {k}")
|
||||||
host_ip = dconfig.get("scheduler_address", None)
|
host_ip = dconfig.get("scheduler_address", None)
|
||||||
|
if host_ip is not None and host_ip.startswith("[") and host_ip.endswith("]"):
|
||||||
|
# convert dask bracket format to proper IPv6 address.
|
||||||
|
host_ip = host_ip[1:-1]
|
||||||
if host_ip is not None:
|
if host_ip is not None:
|
||||||
try:
|
try:
|
||||||
host_ip, port = distributed.comm.get_address_host_port(host_ip)
|
host_ip, port = distributed.comm.get_address_host_port(host_ip)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if host_ip is not None:
|
if host_ip is not None:
|
||||||
user_addr = (host_ip, port)
|
user_addr = (host_ip, port)
|
||||||
else:
|
else:
|
||||||
|
|||||||
41
python-package/xgboost/testing.py
Normal file
41
python-package/xgboost/testing.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
"""Utilities for defining Python tests."""
|
||||||
|
|
||||||
|
import socket
|
||||||
|
from platform import system
|
||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
|
PytestSkip = TypedDict("PytestSkip", {"condition": bool, "reason": str})
|
||||||
|
|
||||||
|
|
||||||
|
def has_ipv6() -> bool:
|
||||||
|
"""Check whether IPv6 is enabled on this host."""
|
||||||
|
# connection error in macos, still need some fixes.
|
||||||
|
if system() not in ("Linux", "Windows"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if socket.has_ipv6:
|
||||||
|
try:
|
||||||
|
with socket.socket(
|
||||||
|
socket.AF_INET6, socket.SOCK_STREAM
|
||||||
|
) as server, socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as client:
|
||||||
|
server.bind(("::1", 0))
|
||||||
|
port = server.getsockname()[1]
|
||||||
|
server.listen()
|
||||||
|
|
||||||
|
client.connect(("::1", port))
|
||||||
|
conn, _ = server.accept()
|
||||||
|
|
||||||
|
client.sendall("abc".encode())
|
||||||
|
msg = conn.recv(3).decode()
|
||||||
|
# if the code can be executed to this point, the message should be
|
||||||
|
# correct.
|
||||||
|
assert msg == "abc"
|
||||||
|
return True
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def skip_ipv6() -> PytestSkip:
|
||||||
|
"""PyTest skip mark for IPv6."""
|
||||||
|
return {"condition": not has_ipv6(), "reason": "IPv6 is required to be enabled."}
|
||||||
@ -112,7 +112,7 @@ class WorkerEntry:
|
|||||||
"""Assign the rank for current entry."""
|
"""Assign the rank for current entry."""
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
nnset = set(tree_map[rank])
|
nnset = set(tree_map[rank])
|
||||||
rprev, rnext = ring_map[rank]
|
rprev, next_rank = ring_map[rank]
|
||||||
self.sock.sendint(rank)
|
self.sock.sendint(rank)
|
||||||
# send parent rank
|
# send parent rank
|
||||||
self.sock.sendint(parent_map[rank])
|
self.sock.sendint(parent_map[rank])
|
||||||
@ -129,9 +129,9 @@ class WorkerEntry:
|
|||||||
else:
|
else:
|
||||||
self.sock.sendint(-1)
|
self.sock.sendint(-1)
|
||||||
# send next link
|
# send next link
|
||||||
if rnext not in (-1, rank):
|
if next_rank not in (-1, rank):
|
||||||
nnset.add(rnext)
|
nnset.add(next_rank)
|
||||||
self.sock.sendint(rnext)
|
self.sock.sendint(next_rank)
|
||||||
else:
|
else:
|
||||||
self.sock.sendint(-1)
|
self.sock.sendint(-1)
|
||||||
|
|
||||||
@ -157,6 +157,7 @@ class WorkerEntry:
|
|||||||
self.sock.sendstr(wait_conn[r].host)
|
self.sock.sendstr(wait_conn[r].host)
|
||||||
port = wait_conn[r].port
|
port = wait_conn[r].port
|
||||||
assert port is not None
|
assert port is not None
|
||||||
|
# send port of this node to other workers so that they can call connect
|
||||||
self.sock.sendint(port)
|
self.sock.sendint(port)
|
||||||
self.sock.sendint(r)
|
self.sock.sendint(r)
|
||||||
nerr = self.sock.recvint()
|
nerr = self.sock.recvint()
|
||||||
|
|||||||
@ -1,65 +1,48 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright (c) 2014-2019 by Contributors
|
* Copyright (c) 2014-2022 by XGBoost Contributors
|
||||||
* \file socket.h
|
* \file socket.h
|
||||||
* \brief this file aims to provide a wrapper of sockets
|
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#ifndef RABIT_INTERNAL_SOCKET_H_
|
#ifndef RABIT_INTERNAL_SOCKET_H_
|
||||||
#define RABIT_INTERNAL_SOCKET_H_
|
#define RABIT_INTERNAL_SOCKET_H_
|
||||||
|
#include "xgboost/collective/socket.h"
|
||||||
|
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
#include <winsock2.h>
|
#include <winsock2.h>
|
||||||
#include <ws2tcpip.h>
|
#include <ws2tcpip.h>
|
||||||
|
|
||||||
#ifdef _MSC_VER
|
|
||||||
#pragma comment(lib, "Ws2_32.lib")
|
|
||||||
#endif // _MSC_VER
|
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
|
#include <arpa/inet.h>
|
||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
#include <netdb.h>
|
#include <netdb.h>
|
||||||
#include <cerrno>
|
|
||||||
#include <unistd.h>
|
|
||||||
#include <arpa/inet.h>
|
|
||||||
#include <netinet/in.h>
|
#include <netinet/in.h>
|
||||||
#include <sys/socket.h>
|
|
||||||
#include <sys/ioctl.h>
|
#include <sys/ioctl.h>
|
||||||
|
#include <sys/socket.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
|
||||||
#if defined(__sun) || defined(sun)
|
#include <cerrno>
|
||||||
#include <sys/sockio.h>
|
|
||||||
#endif // defined(__sun) || defined(sun)
|
|
||||||
|
|
||||||
#endif // defined(_WIN32)
|
#endif // defined(_WIN32)
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <cstring>
|
|
||||||
#include <vector>
|
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
|
#include <cstring>
|
||||||
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
|
||||||
#if defined(_WIN32) && !defined(__MINGW32__)
|
#if !defined(_WIN32)
|
||||||
typedef int ssize_t;
|
|
||||||
#endif // defined(_WIN32) || defined(__MINGW32__)
|
|
||||||
|
|
||||||
#if defined(_WIN32)
|
|
||||||
using sock_size_t = int;
|
|
||||||
|
|
||||||
#else
|
|
||||||
|
|
||||||
#include <sys/poll.h>
|
#include <sys/poll.h>
|
||||||
|
|
||||||
using SOCKET = int;
|
using SOCKET = int;
|
||||||
using sock_size_t = size_t; // NOLINT
|
using sock_size_t = size_t; // NOLINT
|
||||||
#endif // defined(_WIN32)
|
#endif // !defined(_WIN32)
|
||||||
|
|
||||||
#define IS_MINGW() defined(__MINGW32__)
|
#define IS_MINGW() defined(__MINGW32__)
|
||||||
|
|
||||||
#if IS_MINGW()
|
|
||||||
inline void MingWError() {
|
|
||||||
throw dmlc::Error("Distributed training on mingw is not supported.");
|
|
||||||
}
|
|
||||||
#endif // IS_MINGW()
|
|
||||||
|
|
||||||
#if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
|
#if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
|
||||||
/*
|
/*
|
||||||
* On later mingw versions poll should be supported (with bugs). See:
|
* On later mingw versions poll should be supported (with bugs). See:
|
||||||
@ -88,23 +71,17 @@ typedef struct pollfd {
|
|||||||
// POLLWRNORM
|
// POLLWRNORM
|
||||||
#define POLLOUT 0x0010
|
#define POLLOUT 0x0010
|
||||||
|
|
||||||
inline const char *inet_ntop(int, const void *, char *, size_t) {
|
|
||||||
MingWError();
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
#endif // IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
|
#endif // IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace utils {
|
namespace utils {
|
||||||
|
|
||||||
static constexpr int kInvalidSocket = -1;
|
|
||||||
|
|
||||||
template <typename PollFD>
|
template <typename PollFD>
|
||||||
int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) {
|
int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) {
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
|
|
||||||
#if IS_MINGW()
|
#if IS_MINGW()
|
||||||
MingWError();
|
xgboost::MingWError();
|
||||||
return -1;
|
return -1;
|
||||||
#else
|
#else
|
||||||
return WSAPoll(pfd, nfds, std::chrono::milliseconds(timeout).count());
|
return WSAPoll(pfd, nfds, std::chrono::milliseconds(timeout).count());
|
||||||
@ -115,458 +92,6 @@ int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) {
|
|||||||
#endif // IS_MINGW()
|
#endif // IS_MINGW()
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief data structure for network address */
|
|
||||||
struct SockAddr {
|
|
||||||
sockaddr_in addr;
|
|
||||||
// constructor
|
|
||||||
SockAddr() = default;
|
|
||||||
SockAddr(const char *url, int port) {
|
|
||||||
this->Set(url, port);
|
|
||||||
}
|
|
||||||
inline static std::string GetHostName() {
|
|
||||||
std::string buf; buf.resize(256);
|
|
||||||
#if !IS_MINGW()
|
|
||||||
utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
|
|
||||||
#endif // IS_MINGW()
|
|
||||||
return std::string(buf.c_str());
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief set the address
|
|
||||||
* \param url the url of the address
|
|
||||||
* \param port the port of address
|
|
||||||
*/
|
|
||||||
inline void Set(const char *host, int port) {
|
|
||||||
#if !IS_MINGW()
|
|
||||||
addrinfo hints;
|
|
||||||
memset(&hints, 0, sizeof(hints));
|
|
||||||
hints.ai_family = AF_INET;
|
|
||||||
hints.ai_protocol = SOCK_STREAM;
|
|
||||||
addrinfo *res = nullptr;
|
|
||||||
int sig = getaddrinfo(host, nullptr, &hints, &res);
|
|
||||||
Check(sig == 0 && res != nullptr, "cannot obtain address of %s", host);
|
|
||||||
Check(res->ai_family == AF_INET, "Does not support IPv6");
|
|
||||||
memcpy(&addr, res->ai_addr, res->ai_addrlen);
|
|
||||||
addr.sin_port = htons(port);
|
|
||||||
freeaddrinfo(res);
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
}
|
|
||||||
/*! \brief return port of the address*/
|
|
||||||
inline int Port() const {
|
|
||||||
return ntohs(addr.sin_port);
|
|
||||||
}
|
|
||||||
/*! \return a string representation of the address */
|
|
||||||
inline std::string AddrStr() const {
|
|
||||||
std::string buf; buf.resize(256);
|
|
||||||
#ifdef _WIN32
|
|
||||||
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
|
|
||||||
&buf[0], buf.length());
|
|
||||||
#else
|
|
||||||
const char *s = inet_ntop(AF_INET, &addr.sin_addr,
|
|
||||||
&buf[0], buf.length());
|
|
||||||
#endif // _WIN32
|
|
||||||
Assert(s != nullptr, "cannot decode address");
|
|
||||||
return std::string(s);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief base class containing common operations of TCP and UDP sockets
|
|
||||||
*/
|
|
||||||
class Socket {
|
|
||||||
public:
|
|
||||||
/*! \brief the file descriptor of socket */
|
|
||||||
SOCKET sockfd;
|
|
||||||
// default conversion to int
|
|
||||||
operator SOCKET() const { // NOLINT
|
|
||||||
return sockfd;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \return last error of socket operation
|
|
||||||
*/
|
|
||||||
inline static int GetLastError() {
|
|
||||||
#ifdef _WIN32
|
|
||||||
|
|
||||||
#if IS_MINGW()
|
|
||||||
MingWError();
|
|
||||||
return -1;
|
|
||||||
#else
|
|
||||||
return WSAGetLastError();
|
|
||||||
#endif // IS_MINGW()
|
|
||||||
|
|
||||||
#else
|
|
||||||
return errno;
|
|
||||||
#endif // _WIN32
|
|
||||||
}
|
|
||||||
/*! \return whether last error was would block */
|
|
||||||
inline static bool LastErrorWouldBlock() {
|
|
||||||
int errsv = GetLastError();
|
|
||||||
#ifdef _WIN32
|
|
||||||
return errsv == WSAEWOULDBLOCK;
|
|
||||||
#else
|
|
||||||
return errsv == EAGAIN || errsv == EWOULDBLOCK;
|
|
||||||
#endif // _WIN32
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief start up the socket module
|
|
||||||
* call this before using the sockets
|
|
||||||
*/
|
|
||||||
inline static void Startup() {
|
|
||||||
#ifdef _WIN32
|
|
||||||
#if !IS_MINGW()
|
|
||||||
WSADATA wsa_data;
|
|
||||||
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
|
|
||||||
Socket::Error("Startup");
|
|
||||||
}
|
|
||||||
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
|
|
||||||
WSACleanup();
|
|
||||||
utils::Error("Could not find a usable version of Winsock.dll\n");
|
|
||||||
}
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
#endif // _WIN32
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief shutdown the socket module after use, all sockets need to be closed
|
|
||||||
*/
|
|
||||||
inline static void Finalize() {
|
|
||||||
#ifdef _WIN32
|
|
||||||
#if !IS_MINGW()
|
|
||||||
WSACleanup();
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
#endif // _WIN32
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief set this socket to use non-blocking mode
|
|
||||||
* \param non_block whether set it to be non-block, if it is false
|
|
||||||
* it will set it back to block mode
|
|
||||||
*/
|
|
||||||
inline void SetNonBlock(bool non_block) {
|
|
||||||
#ifdef _WIN32
|
|
||||||
#if !IS_MINGW()
|
|
||||||
u_long mode = non_block ? 1 : 0;
|
|
||||||
if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
|
|
||||||
Socket::Error("SetNonBlock");
|
|
||||||
}
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
#else
|
|
||||||
int flag = fcntl(sockfd, F_GETFL, 0);
|
|
||||||
if (flag == -1) {
|
|
||||||
Socket::Error("SetNonBlock-1");
|
|
||||||
}
|
|
||||||
if (non_block) {
|
|
||||||
flag |= O_NONBLOCK;
|
|
||||||
} else {
|
|
||||||
flag &= ~O_NONBLOCK;
|
|
||||||
}
|
|
||||||
if (fcntl(sockfd, F_SETFL, flag) == -1) {
|
|
||||||
Socket::Error("SetNonBlock-2");
|
|
||||||
}
|
|
||||||
#endif // _WIN32
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief bind the socket to an address
|
|
||||||
* \param addr
|
|
||||||
*/
|
|
||||||
inline void Bind(const SockAddr &addr) {
|
|
||||||
#if !IS_MINGW()
|
|
||||||
if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
|
|
||||||
sizeof(addr.addr)) == -1) {
|
|
||||||
Socket::Error("Bind");
|
|
||||||
}
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief try bind the socket to host, from start_port to end_port
|
|
||||||
* \param start_port starting port number to try
|
|
||||||
* \param end_port ending port number to try
|
|
||||||
* \return the port successfully bind to, return -1 if failed to bind any port
|
|
||||||
*/
|
|
||||||
inline int TryBindHost(int start_port, int end_port) {
|
|
||||||
// TODO(tqchen) add prefix check
|
|
||||||
#if !IS_MINGW()
|
|
||||||
for (int port = start_port; port < end_port; ++port) {
|
|
||||||
SockAddr addr("0.0.0.0", port);
|
|
||||||
if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
|
|
||||||
sizeof(addr.addr)) == 0) {
|
|
||||||
return port;
|
|
||||||
}
|
|
||||||
#if defined(_WIN32)
|
|
||||||
if (WSAGetLastError() != WSAEADDRINUSE) {
|
|
||||||
Socket::Error("TryBindHost");
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
if (errno != EADDRINUSE) {
|
|
||||||
Socket::Error("TryBindHost");
|
|
||||||
}
|
|
||||||
#endif // defined(_WIN32)
|
|
||||||
}
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
/*! \brief get last error code if any */
|
|
||||||
inline int GetSockError() const {
|
|
||||||
int error = 0;
|
|
||||||
socklen_t len = sizeof(error);
|
|
||||||
#if !IS_MINGW()
|
|
||||||
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
|
|
||||||
reinterpret_cast<char *>(&error), &len) != 0) {
|
|
||||||
Error("GetSockError");
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
// undefined reference to `_imp__getsockopt@20'
|
|
||||||
MingWError();
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
return error;
|
|
||||||
}
|
|
||||||
/*! \brief check if anything bad happens */
|
|
||||||
inline bool BadSocket() const {
|
|
||||||
if (IsClosed()) return true;
|
|
||||||
int err = GetSockError();
|
|
||||||
if (err == EBADF || err == EINTR) return true;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
/*! \brief check if socket is already closed */
|
|
||||||
inline bool IsClosed() const {
|
|
||||||
return sockfd == kInvalidSocket;
|
|
||||||
}
|
|
||||||
/*! \brief close the socket */
|
|
||||||
inline void Close() {
|
|
||||||
if (sockfd != kInvalidSocket) {
|
|
||||||
#ifdef _WIN32
|
|
||||||
#if !IS_MINGW()
|
|
||||||
closesocket(sockfd);
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
#else
|
|
||||||
close(sockfd);
|
|
||||||
#endif
|
|
||||||
sockfd = kInvalidSocket;
|
|
||||||
} else {
|
|
||||||
Error("Socket::Close double close the socket or close without create");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// report an socket error
|
|
||||||
inline static void Error(const char *msg) {
|
|
||||||
int errsv = GetLastError();
|
|
||||||
#ifdef _WIN32
|
|
||||||
utils::Error("Socket %s Error:WSAError-code=%d", msg, errsv);
|
|
||||||
#else
|
|
||||||
utils::Error("Socket %s Error:%s", msg, strerror(errsv));
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
explicit Socket(SOCKET sockfd) : sockfd(sockfd) {
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief a wrapper of TCP socket that hopefully be cross platform
|
|
||||||
*/
|
|
||||||
class TCPSocket : public Socket{
|
|
||||||
public:
|
|
||||||
// constructor
|
|
||||||
TCPSocket() : Socket(kInvalidSocket) {
|
|
||||||
}
|
|
||||||
explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) {
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief enable/disable TCP keepalive
|
|
||||||
* \param keepalive whether to set the keep alive option on
|
|
||||||
*/
|
|
||||||
void SetKeepAlive(bool keepalive) {
|
|
||||||
#if !IS_MINGW()
|
|
||||||
int opt = static_cast<int>(keepalive);
|
|
||||||
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
|
|
||||||
reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
|
|
||||||
Socket::Error("SetKeepAlive");
|
|
||||||
}
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
}
|
|
||||||
inline void SetLinger(int timeout = 0) {
|
|
||||||
#if !IS_MINGW()
|
|
||||||
struct linger sl;
|
|
||||||
sl.l_onoff = 1; /* non-zero value enables linger option in kernel */
|
|
||||||
sl.l_linger = timeout; /* timeout interval in seconds */
|
|
||||||
if (setsockopt(sockfd, SOL_SOCKET, SO_LINGER, reinterpret_cast<char*>(&sl), sizeof(sl)) == -1) {
|
|
||||||
Socket::Error("SO_LINGER");
|
|
||||||
}
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief create the socket, call this before using socket
|
|
||||||
* \param af domain
|
|
||||||
*/
|
|
||||||
inline void Create(int af = PF_INET) {
|
|
||||||
#if !IS_MINGW()
|
|
||||||
sockfd = socket(af, SOCK_STREAM, 0);
|
|
||||||
if (sockfd == kInvalidSocket) {
|
|
||||||
Socket::Error("Create");
|
|
||||||
}
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief perform listen of the socket
|
|
||||||
* \param backlog backlog parameter
|
|
||||||
*/
|
|
||||||
inline void Listen(int backlog = 16) {
|
|
||||||
#if !IS_MINGW()
|
|
||||||
listen(sockfd, backlog);
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
}
|
|
||||||
/*! \brief get a new connection */
|
|
||||||
TCPSocket Accept() {
|
|
||||||
#if !IS_MINGW()
|
|
||||||
SOCKET newfd = accept(sockfd, nullptr, nullptr);
|
|
||||||
if (newfd == kInvalidSocket) {
|
|
||||||
Socket::Error("Accept");
|
|
||||||
}
|
|
||||||
return TCPSocket(newfd);
|
|
||||||
#else
|
|
||||||
return TCPSocket();
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief decide whether the socket is at OOB mark
|
|
||||||
* \return 1 if at mark, 0 if not, -1 if an error occured
|
|
||||||
*/
|
|
||||||
inline int AtMark() const {
|
|
||||||
#if !IS_MINGW()
|
|
||||||
|
|
||||||
#ifdef _WIN32
|
|
||||||
unsigned long atmark; // NOLINT(*)
|
|
||||||
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
|
|
||||||
#else
|
|
||||||
int atmark;
|
|
||||||
if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
|
|
||||||
#endif // _WIN32
|
|
||||||
|
|
||||||
return static_cast<int>(atmark);
|
|
||||||
|
|
||||||
#else
|
|
||||||
return -1;
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief connect to an address
|
|
||||||
* \param addr the address to connect to
|
|
||||||
* \return whether connect is successful
|
|
||||||
*/
|
|
||||||
inline bool Connect(const SockAddr &addr) {
|
|
||||||
#if !IS_MINGW()
|
|
||||||
return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
|
|
||||||
sizeof(addr.addr)) == 0;
|
|
||||||
#else
|
|
||||||
return false;
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief send data using the socket
|
|
||||||
* \param buf the pointer to the buffer
|
|
||||||
* \param len the size of the buffer
|
|
||||||
* \param flags extra flags
|
|
||||||
* \return size of data actually sent
|
|
||||||
* return -1 if error occurs
|
|
||||||
*/
|
|
||||||
inline ssize_t Send(const void *buf_, size_t len, int flag = 0) {
|
|
||||||
const char *buf = reinterpret_cast<const char*>(buf_);
|
|
||||||
#if !IS_MINGW()
|
|
||||||
return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
|
|
||||||
#else
|
|
||||||
return 0;
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief receive data using the socket
|
|
||||||
* \param buf_ the pointer to the buffer
|
|
||||||
* \param len the size of the buffer
|
|
||||||
* \param flags extra flags
|
|
||||||
* \return size of data actually received
|
|
||||||
* return -1 if error occurs
|
|
||||||
*/
|
|
||||||
inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
|
|
||||||
char *buf = reinterpret_cast<char*>(buf_);
|
|
||||||
#if !IS_MINGW()
|
|
||||||
return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
|
|
||||||
#else
|
|
||||||
return 0;
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief peform block write that will attempt to send all data out
|
|
||||||
* can still return smaller than request when error occurs
|
|
||||||
* \param buf the pointer to the buffer
|
|
||||||
* \param len the size of the buffer
|
|
||||||
* \return size of data actually sent
|
|
||||||
*/
|
|
||||||
inline size_t SendAll(const void *buf_, size_t len) {
|
|
||||||
const char *buf = reinterpret_cast<const char*>(buf_);
|
|
||||||
size_t ndone = 0;
|
|
||||||
#if !IS_MINGW()
|
|
||||||
while (ndone < len) {
|
|
||||||
ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
|
|
||||||
if (ret == -1) {
|
|
||||||
if (LastErrorWouldBlock()) return ndone;
|
|
||||||
Socket::Error("SendAll");
|
|
||||||
}
|
|
||||||
buf += ret;
|
|
||||||
ndone += ret;
|
|
||||||
}
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
return ndone;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief peforma block read that will attempt to read all data
|
|
||||||
* can still return smaller than request when error occurs
|
|
||||||
* \param buf_ the buffer pointer
|
|
||||||
* \param len length of data to recv
|
|
||||||
* \return size of data actually sent
|
|
||||||
*/
|
|
||||||
inline size_t RecvAll(void *buf_, size_t len) {
|
|
||||||
char *buf = reinterpret_cast<char*>(buf_);
|
|
||||||
size_t ndone = 0;
|
|
||||||
#if !IS_MINGW()
|
|
||||||
while (ndone < len) {
|
|
||||||
ssize_t ret = recv(sockfd, buf,
|
|
||||||
static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
|
|
||||||
if (ret == -1) {
|
|
||||||
if (LastErrorWouldBlock()) return ndone;
|
|
||||||
Socket::Error("RecvAll");
|
|
||||||
}
|
|
||||||
if (ret == 0) return ndone;
|
|
||||||
buf += ret;
|
|
||||||
ndone += ret;
|
|
||||||
}
|
|
||||||
#endif // !IS_MINGW()
|
|
||||||
return ndone;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief send a string over network
|
|
||||||
* \param str the string to be sent
|
|
||||||
*/
|
|
||||||
inline void SendStr(const std::string &str) {
|
|
||||||
int len = static_cast<int>(str.length());
|
|
||||||
utils::Assert(this->SendAll(&len, sizeof(len)) == sizeof(len),
|
|
||||||
"error during send SendStr");
|
|
||||||
if (len != 0) {
|
|
||||||
utils::Assert(this->SendAll(str.c_str(), str.length()) == str.length(),
|
|
||||||
"error during send SendStr");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief recv a string from network
|
|
||||||
* \param out_str the string to receive
|
|
||||||
*/
|
|
||||||
inline void RecvStr(std::string *out_str) {
|
|
||||||
int len;
|
|
||||||
utils::Assert(this->RecvAll(&len, sizeof(len)) == sizeof(len),
|
|
||||||
"error during send RecvStr");
|
|
||||||
out_str->resize(len);
|
|
||||||
if (len != 0) {
|
|
||||||
utils::Assert(this->RecvAll(&(*out_str)[0], len) == out_str->length(),
|
|
||||||
"error during send SendStr");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/*! \brief helper data structure to perform poll */
|
/*! \brief helper data structure to perform poll */
|
||||||
struct PollHelper {
|
struct PollHelper {
|
||||||
public:
|
public:
|
||||||
@ -579,6 +104,8 @@ struct PollHelper {
|
|||||||
pfd.fd = fd;
|
pfd.fd = fd;
|
||||||
pfd.events |= POLLIN;
|
pfd.events |= POLLIN;
|
||||||
}
|
}
|
||||||
|
void WatchRead(xgboost::collective::TCPSocket const &socket) { this->WatchRead(socket.Handle()); }
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief add file descriptor to watch for write
|
* \brief add file descriptor to watch for write
|
||||||
* \param fd file descriptor to be watched
|
* \param fd file descriptor to be watched
|
||||||
@ -588,6 +115,10 @@ struct PollHelper {
|
|||||||
pfd.fd = fd;
|
pfd.fd = fd;
|
||||||
pfd.events |= POLLOUT;
|
pfd.events |= POLLOUT;
|
||||||
}
|
}
|
||||||
|
void WatchWrite(xgboost::collective::TCPSocket const &socket) {
|
||||||
|
this->WatchWrite(socket.Handle());
|
||||||
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief add file descriptor to watch for exception
|
* \brief add file descriptor to watch for exception
|
||||||
* \param fd file descriptor to be watched
|
* \param fd file descriptor to be watched
|
||||||
@ -597,6 +128,9 @@ struct PollHelper {
|
|||||||
pfd.fd = fd;
|
pfd.fd = fd;
|
||||||
pfd.events |= POLLPRI;
|
pfd.events |= POLLPRI;
|
||||||
}
|
}
|
||||||
|
void WatchException(xgboost::collective::TCPSocket const &socket) {
|
||||||
|
this->WatchException(socket.Handle());
|
||||||
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief Check if the descriptor is ready for read
|
* \brief Check if the descriptor is ready for read
|
||||||
* \param fd file descriptor to check status
|
* \param fd file descriptor to check status
|
||||||
@ -605,6 +139,10 @@ struct PollHelper {
|
|||||||
const auto& pfd = fds.find(fd);
|
const auto& pfd = fds.find(fd);
|
||||||
return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
|
return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
|
||||||
}
|
}
|
||||||
|
bool CheckRead(xgboost::collective::TCPSocket const &socket) const {
|
||||||
|
return this->CheckRead(socket.Handle());
|
||||||
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Check if the descriptor is ready for write
|
* \brief Check if the descriptor is ready for write
|
||||||
* \param fd file descriptor to check status
|
* \param fd file descriptor to check status
|
||||||
@ -613,7 +151,9 @@ struct PollHelper {
|
|||||||
const auto& pfd = fds.find(fd);
|
const auto& pfd = fds.find(fd);
|
||||||
return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
|
return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
|
||||||
}
|
}
|
||||||
|
bool CheckWrite(xgboost::collective::TCPSocket const &socket) const {
|
||||||
|
return this->CheckWrite(socket.Handle());
|
||||||
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief perform poll on the set defined, read, write, exception
|
* \brief perform poll on the set defined, read, write, exception
|
||||||
* \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
|
* \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
|
||||||
@ -629,7 +169,7 @@ struct PollHelper {
|
|||||||
if (ret == 0) {
|
if (ret == 0) {
|
||||||
LOG(FATAL) << "Poll timeout";
|
LOG(FATAL) << "Poll timeout";
|
||||||
} else if (ret < 0) {
|
} else if (ret < 0) {
|
||||||
Socket::Error("Poll");
|
LOG(FATAL) << "Failed to poll.";
|
||||||
} else {
|
} else {
|
||||||
for (auto& pfd : fdset) {
|
for (auto& pfd : fdset) {
|
||||||
auto revents = pfd.revents & pfd.events;
|
auto revents = pfd.revents & pfd.events;
|
||||||
|
|||||||
@ -8,15 +8,17 @@
|
|||||||
#define RABIT_INTERNAL_UTILS_H_
|
#define RABIT_INTERNAL_UTILS_H_
|
||||||
|
|
||||||
#include <rabit/base.h>
|
#include <rabit/base.h>
|
||||||
#include <cstring>
|
|
||||||
|
#include <cstdarg>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <string>
|
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "dmlc/io.h"
|
#include "dmlc/io.h"
|
||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
#include <cstdarg>
|
|
||||||
|
|
||||||
#if !defined(__GNUC__) || defined(__FreeBSD__)
|
#if !defined(__GNUC__) || defined(__FreeBSD__)
|
||||||
#define fopen64 std::fopen
|
#define fopen64 std::fopen
|
||||||
|
|||||||
@ -27,8 +27,6 @@ AllreduceBase::AllreduceBase() {
|
|||||||
tracker_uri = "NULL";
|
tracker_uri = "NULL";
|
||||||
tracker_port = 9000;
|
tracker_port = 9000;
|
||||||
host_uri = "";
|
host_uri = "";
|
||||||
slave_port = 9010;
|
|
||||||
nport_trial = 1000;
|
|
||||||
rank = 0;
|
rank = 0;
|
||||||
world_size = -1;
|
world_size = -1;
|
||||||
connect_retry = 5;
|
connect_retry = 5;
|
||||||
@ -114,16 +112,16 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
|
|||||||
this->rank = -1;
|
this->rank = -1;
|
||||||
//---------------------
|
//---------------------
|
||||||
// start socket
|
// start socket
|
||||||
utils::Socket::Startup();
|
xgboost::system::SocketStartup();
|
||||||
utils::Assert(all_links.size() == 0, "can only call Init once");
|
utils::Assert(all_links.size() == 0, "can only call Init once");
|
||||||
this->host_uri = utils::SockAddr::GetHostName();
|
this->host_uri = xgboost::collective::GetHostName();
|
||||||
// get information from tracker
|
// get information from tracker
|
||||||
return this->ReConnectLinks();
|
return this->ReConnectLinks();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AllreduceBase::Shutdown() {
|
bool AllreduceBase::Shutdown() {
|
||||||
try {
|
try {
|
||||||
for (auto & all_link : all_links) {
|
for (auto &all_link : all_links) {
|
||||||
if (!all_link.sock.IsClosed()) {
|
if (!all_link.sock.IsClosed()) {
|
||||||
all_link.sock.Close();
|
all_link.sock.Close();
|
||||||
}
|
}
|
||||||
@ -133,12 +131,12 @@ bool AllreduceBase::Shutdown() {
|
|||||||
|
|
||||||
if (tracker_uri == "NULL") return true;
|
if (tracker_uri == "NULL") return true;
|
||||||
// notify tracker rank i have shutdown
|
// notify tracker rank i have shutdown
|
||||||
utils::TCPSocket tracker = this->ConnectTracker();
|
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
|
||||||
tracker.SendStr(std::string("shutdown"));
|
tracker.Send(xgboost::StringView{"shutdown"});
|
||||||
tracker.Close();
|
tracker.Close();
|
||||||
utils::TCPSocket::Finalize();
|
xgboost::system::SocketFinalize();
|
||||||
return true;
|
return true;
|
||||||
} catch (const std::exception& e) {
|
} catch (std::exception const &e) {
|
||||||
LOG(WARNING) << "Failed to shutdown due to" << e.what();
|
LOG(WARNING) << "Failed to shutdown due to" << e.what();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -148,9 +146,9 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
|
|||||||
if (tracker_uri == "NULL") {
|
if (tracker_uri == "NULL") {
|
||||||
utils::Printf("%s", msg.c_str()); return;
|
utils::Printf("%s", msg.c_str()); return;
|
||||||
}
|
}
|
||||||
utils::TCPSocket tracker = this->ConnectTracker();
|
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
|
||||||
tracker.SendStr(std::string("print"));
|
tracker.Send(xgboost::StringView{"print"});
|
||||||
tracker.SendStr(msg);
|
tracker.Send(xgboost::StringView{msg});
|
||||||
tracker.Close();
|
tracker.Close();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -227,21 +225,23 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
* \brief initialize connection to the tracker
|
* \brief initialize connection to the tracker
|
||||||
* \return a socket that initializes the connection
|
* \return a socket that initializes the connection
|
||||||
*/
|
*/
|
||||||
utils::TCPSocket AllreduceBase::ConnectTracker() const {
|
xgboost::collective::TCPSocket AllreduceBase::ConnectTracker() const {
|
||||||
int magic = kMagic;
|
int magic = kMagic;
|
||||||
// get information from tracker
|
// get information from tracker
|
||||||
utils::TCPSocket tracker;
|
xgboost::collective::TCPSocket tracker;
|
||||||
tracker.Create();
|
|
||||||
|
|
||||||
int retry = 0;
|
int retry = 0;
|
||||||
do {
|
do {
|
||||||
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
|
auto rc = xgboost::collective::Connect(
|
||||||
|
xgboost::collective::MakeSockAddress(xgboost::StringView{tracker_uri}, tracker_port),
|
||||||
|
&tracker);
|
||||||
|
if (rc != std::errc()) {
|
||||||
if (++retry >= connect_retry) {
|
if (++retry >= connect_retry) {
|
||||||
LOG(WARNING) << "Connect to (failed): [" << tracker_uri << "]\n";
|
LOG(FATAL) << "Connecting to (failed): [" << tracker_uri << "]\n" << rc.message();
|
||||||
utils::Socket::Error("Connect");
|
|
||||||
} else {
|
} else {
|
||||||
LOG(WARNING) << "Retry connect to ip(retry time " << retry << "): [" << tracker_uri << "]\n";
|
LOG(WARNING) << rc.message() << "\nRetry connecting to IP(retry time: " << retry << "): ["
|
||||||
#if defined(_MSC_VER) || defined (__MINGW32__)
|
<< tracker_uri << "]";
|
||||||
|
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||||
Sleep(retry << 1);
|
Sleep(retry << 1);
|
||||||
#else
|
#else
|
||||||
sleep(retry << 1);
|
sleep(retry << 1);
|
||||||
@ -253,16 +253,13 @@ utils::TCPSocket AllreduceBase::ConnectTracker() const {
|
|||||||
} while (true);
|
} while (true);
|
||||||
|
|
||||||
using utils::Assert;
|
using utils::Assert;
|
||||||
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
|
CHECK_EQ(tracker.SendAll(&magic, sizeof(magic)), sizeof(magic));
|
||||||
"ReConnectLink failure 1");
|
CHECK_EQ(tracker.RecvAll(&magic, sizeof(magic)), sizeof(magic));
|
||||||
Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic),
|
|
||||||
"ReConnectLink failure 2");
|
|
||||||
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
|
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
|
||||||
Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
||||||
"ReConnectLink failure 3");
|
|
||||||
Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
||||||
"ReConnectLink failure 3");
|
"ReConnectLink failure 3");
|
||||||
tracker.SendStr(task_id);
|
CHECK_EQ(tracker.Send(xgboost::StringView{task_id}), task_id.size());
|
||||||
return tracker;
|
return tracker;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -272,12 +269,15 @@ utils::TCPSocket AllreduceBase::ConnectTracker() const {
|
|||||||
bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||||
// single node mode
|
// single node mode
|
||||||
if (tracker_uri == "NULL") {
|
if (tracker_uri == "NULL") {
|
||||||
rank = 0; world_size = 1; return true;
|
rank = 0;
|
||||||
|
world_size = 1;
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
utils::TCPSocket tracker = this->ConnectTracker();
|
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
|
||||||
LOG(INFO) << "task " << task_id << " connected to the tracker";
|
LOG(INFO) << "task " << task_id << " connected to the tracker";
|
||||||
tracker.SendStr(std::string(cmd));
|
tracker.Send(xgboost::StringView{cmd});
|
||||||
|
|
||||||
// the rank of previous link, next link in ring
|
// the rank of previous link, next link in ring
|
||||||
int prev_rank, next_rank;
|
int prev_rank, next_rank;
|
||||||
@ -315,13 +315,9 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
|
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
|
||||||
"ReConnectLink failure 4");
|
"ReConnectLink failure 4");
|
||||||
|
|
||||||
utils::TCPSocket sock_listen;
|
auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())};
|
||||||
if (!sock_listen.IsClosed()) {
|
|
||||||
sock_listen.Close();
|
|
||||||
}
|
|
||||||
// create listening socket
|
// create listening socket
|
||||||
sock_listen.Create();
|
int port = sock_listen.BindHost();
|
||||||
int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial);
|
|
||||||
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
|
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
|
||||||
sock_listen.Listen();
|
sock_listen.Listen();
|
||||||
|
|
||||||
@ -338,29 +334,27 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
int ngood = static_cast<int>(good_link.size());
|
int ngood = static_cast<int>(good_link.size());
|
||||||
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
|
// tracker construct goodset
|
||||||
"ReConnectLink failure 5");
|
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood), "ReConnectLink failure 5");
|
||||||
for (int & i : good_link) {
|
for (int &i : good_link) {
|
||||||
Assert(tracker.SendAll(&i, sizeof(i)) == \
|
Assert(tracker.SendAll(&i, sizeof(i)) == sizeof(i), "ReConnectLink failure 6");
|
||||||
sizeof(i), "ReConnectLink failure 6");
|
|
||||||
}
|
}
|
||||||
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
||||||
"ReConnectLink failure 7");
|
"ReConnectLink failure 7");
|
||||||
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \
|
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept),
|
||||||
sizeof(num_accept), "ReConnectLink failure 8");
|
"ReConnectLink failure 8");
|
||||||
num_error = 0;
|
num_error = 0;
|
||||||
for (int i = 0; i < num_conn; ++i) {
|
for (int i = 0; i < num_conn; ++i) {
|
||||||
LinkRecord r;
|
LinkRecord r;
|
||||||
int hport, hrank;
|
int hport, hrank;
|
||||||
std::string hname;
|
std::string hname;
|
||||||
tracker.RecvStr(&hname);
|
tracker.Recv(&hname);
|
||||||
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
|
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
|
||||||
"ReConnectLink failure 9");
|
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
|
||||||
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
|
|
||||||
"ReConnectLink failure 10");
|
|
||||||
|
|
||||||
r.sock.Create();
|
if (xgboost::collective::Connect(
|
||||||
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
|
xgboost::collective::MakeSockAddress(xgboost::StringView{hname}, hport), &r.sock) !=
|
||||||
|
std::errc{}) {
|
||||||
num_error += 1;
|
num_error += 1;
|
||||||
r.sock.Close();
|
r.sock.Close();
|
||||||
continue;
|
continue;
|
||||||
@ -376,12 +370,12 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
if (all_link.rank == hrank) {
|
if (all_link.rank == hrank) {
|
||||||
Assert(all_link.sock.IsClosed(),
|
Assert(all_link.sock.IsClosed(),
|
||||||
"Override a link that is active");
|
"Override a link that is active");
|
||||||
all_link.sock = r.sock;
|
all_link.sock = std::move(r.sock);
|
||||||
match = true;
|
match = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!match) all_links.push_back(r);
|
if (!match) all_links.emplace_back(std::move(r));
|
||||||
}
|
}
|
||||||
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
|
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
|
||||||
"ReConnectLink failure 14");
|
"ReConnectLink failure 14");
|
||||||
@ -404,30 +398,24 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
if (all_link.rank == r.rank) {
|
if (all_link.rank == r.rank) {
|
||||||
utils::Assert(all_link.sock.IsClosed(),
|
utils::Assert(all_link.sock.IsClosed(),
|
||||||
"Override a link that is active");
|
"Override a link that is active");
|
||||||
all_link.sock = r.sock;
|
all_link.sock = std::move(r.sock);
|
||||||
match = true;
|
match = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!match) all_links.push_back(r);
|
if (!match) all_links.emplace_back(std::move(r));
|
||||||
}
|
}
|
||||||
sock_listen.Close();
|
sock_listen.Close();
|
||||||
this->parent_index = -1;
|
this->parent_index = -1;
|
||||||
// setup tree links and ring structure
|
// setup tree links and ring structure
|
||||||
tree_links.plinks.clear();
|
tree_links.plinks.clear();
|
||||||
int tcpNoDelay = 1;
|
for (auto &all_link : all_links) {
|
||||||
for (auto & all_link : all_links) {
|
|
||||||
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
|
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
|
||||||
// set the socket to non-blocking mode, enable TCP keepalive
|
// set the socket to non-blocking mode, enable TCP keepalive
|
||||||
all_link.sock.SetNonBlock(true);
|
all_link.sock.SetNonBlock();
|
||||||
all_link.sock.SetKeepAlive(true);
|
all_link.sock.SetKeepAlive();
|
||||||
if (rabit_enable_tcp_no_delay) {
|
if (rabit_enable_tcp_no_delay) {
|
||||||
#if defined(__unix__)
|
all_link.sock.SetNoDelay();
|
||||||
setsockopt(all_link.sock, IPPROTO_TCP,
|
|
||||||
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
|
|
||||||
#else
|
|
||||||
LOG(WARNING) << "tcp no delay is not implemented on non unix platforms";
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
if (tree_neighbors.count(all_link.rank) != 0) {
|
if (tree_neighbors.count(all_link.rank) != 0) {
|
||||||
if (all_link.rank == parent_rank) {
|
if (all_link.rank == parent_rank) {
|
||||||
|
|||||||
@ -201,8 +201,8 @@ class AllreduceBase : public IEngine {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
/*! \brief translate errno to return type */
|
/*! \brief translate errno to return type */
|
||||||
inline static ReturnType Errno2Return() {
|
static ReturnType Errno2Return() {
|
||||||
int errsv = utils::Socket::GetLastError();
|
int errsv = xgboost::system::LastError();
|
||||||
if (errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == 0) return kSuccess;
|
if (errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == 0) return kSuccess;
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
if (errsv == WSAEWOULDBLOCK) return kSuccess;
|
if (errsv == WSAEWOULDBLOCK) return kSuccess;
|
||||||
@ -215,7 +215,7 @@ class AllreduceBase : public IEngine {
|
|||||||
struct LinkRecord {
|
struct LinkRecord {
|
||||||
public:
|
public:
|
||||||
// socket to get data from/to link
|
// socket to get data from/to link
|
||||||
utils::TCPSocket sock;
|
xgboost::collective::TCPSocket sock;
|
||||||
// rank of the node in this link
|
// rank of the node in this link
|
||||||
int rank;
|
int rank;
|
||||||
// size of data readed from link
|
// size of data readed from link
|
||||||
@ -329,7 +329,7 @@ class AllreduceBase : public IEngine {
|
|||||||
* \brief initialize connection to the tracker
|
* \brief initialize connection to the tracker
|
||||||
* \return a socket that initializes the connection
|
* \return a socket that initializes the connection
|
||||||
*/
|
*/
|
||||||
utils::TCPSocket ConnectTracker() const;
|
xgboost::collective::TCPSocket ConnectTracker() const;
|
||||||
/*!
|
/*!
|
||||||
* \brief connect to the tracker to fix the the missing links
|
* \brief connect to the tracker to fix the the missing links
|
||||||
* this function is also used when the engine start up
|
* this function is also used when the engine start up
|
||||||
@ -473,8 +473,6 @@ class AllreduceBase : public IEngine {
|
|||||||
std::string dmlc_role; // NOLINT
|
std::string dmlc_role; // NOLINT
|
||||||
// port of tracker address
|
// port of tracker address
|
||||||
int tracker_port; // NOLINT
|
int tracker_port; // NOLINT
|
||||||
// port of slave process
|
|
||||||
int slave_port, nport_trial; // NOLINT
|
|
||||||
// reduce buffer size
|
// reduce buffer size
|
||||||
size_t reduce_buffer_size; // NOLINT
|
size_t reduce_buffer_size; // NOLINT
|
||||||
// reduction method
|
// reduction method
|
||||||
|
|||||||
94
src/collective/socket.cc
Normal file
94
src/collective/socket.cc
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright (c) 2022 by XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include "xgboost/collective/socket.h"
|
||||||
|
|
||||||
|
#include <cstddef> // std::size_t
|
||||||
|
#include <cstdint> // std::int32_t
|
||||||
|
#include <cstring> // std::memcpy, std::memset
|
||||||
|
#include <system_error> // std::error_code, std::system_category
|
||||||
|
|
||||||
|
#if defined(__unix__) || defined(__APPLE__)
|
||||||
|
#include <netdb.h> // getaddrinfo, freeaddrinfo
|
||||||
|
#endif // defined(__unix__) || defined(__APPLE__)
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
SockAddress MakeSockAddress(StringView host, in_port_t port) {
|
||||||
|
struct addrinfo hints;
|
||||||
|
std::memset(&hints, 0, sizeof(hints));
|
||||||
|
hints.ai_protocol = SOCK_STREAM;
|
||||||
|
struct addrinfo *res = nullptr;
|
||||||
|
int sig = getaddrinfo(host.c_str(), nullptr, &hints, &res);
|
||||||
|
if (sig != 0) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
if (res->ai_family == static_cast<std::int32_t>(SockDomain::kV4)) {
|
||||||
|
sockaddr_in addr;
|
||||||
|
std::memcpy(&addr, res->ai_addr, res->ai_addrlen);
|
||||||
|
addr.sin_port = htons(port);
|
||||||
|
auto v = SockAddrV4{addr};
|
||||||
|
freeaddrinfo(res);
|
||||||
|
return SockAddress{v};
|
||||||
|
} else if (res->ai_family == static_cast<std::int32_t>(SockDomain::kV6)) {
|
||||||
|
sockaddr_in6 addr;
|
||||||
|
std::memcpy(&addr, res->ai_addr, res->ai_addrlen);
|
||||||
|
|
||||||
|
addr.sin6_port = htons(port);
|
||||||
|
auto v = SockAddrV6{addr};
|
||||||
|
freeaddrinfo(res);
|
||||||
|
return SockAddress{v};
|
||||||
|
} else {
|
||||||
|
LOG(FATAL) << "Failed to get addr info for: " << host;
|
||||||
|
}
|
||||||
|
|
||||||
|
return SockAddress{};
|
||||||
|
}
|
||||||
|
|
||||||
|
SockAddrV4 SockAddrV4::Loopback() { return MakeSockAddress("127.0.0.1", 0).V4(); }
|
||||||
|
SockAddrV4 SockAddrV4::InaddrAny() { return MakeSockAddress("0.0.0.0", 0).V4(); }
|
||||||
|
|
||||||
|
SockAddrV6 SockAddrV6::Loopback() { return MakeSockAddress("::1", 0).V6(); }
|
||||||
|
SockAddrV6 SockAddrV6::InaddrAny() { return MakeSockAddress("::", 0).V6(); }
|
||||||
|
|
||||||
|
std::size_t TCPSocket::Send(StringView str) {
|
||||||
|
CHECK(!this->IsClosed());
|
||||||
|
CHECK_LT(str.size(), std::numeric_limits<std::int32_t>::max());
|
||||||
|
std::int32_t len = static_cast<std::int32_t>(str.size());
|
||||||
|
CHECK_EQ(this->SendAll(&len, sizeof(len)), sizeof(len)) << "Failed to send string length.";
|
||||||
|
auto bytes = this->SendAll(str.c_str(), str.size());
|
||||||
|
CHECK_EQ(bytes, str.size()) << "Failed to send string.";
|
||||||
|
return bytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::size_t TCPSocket::Recv(std::string *p_str) {
|
||||||
|
CHECK(!this->IsClosed());
|
||||||
|
std::int32_t len;
|
||||||
|
CHECK_EQ(this->RecvAll(&len, sizeof(len)), sizeof(len)) << "Failed to recv string length.";
|
||||||
|
p_str->resize(len);
|
||||||
|
auto bytes = this->RecvAll(&(*p_str)[0], len);
|
||||||
|
CHECK_EQ(bytes, len) << "Failed to recv string.";
|
||||||
|
return bytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::error_code Connect(SockAddress const &addr, TCPSocket *out) {
|
||||||
|
sockaddr const *addr_handle{nullptr};
|
||||||
|
socklen_t addr_len{0};
|
||||||
|
if (addr.IsV4()) {
|
||||||
|
addr_handle = reinterpret_cast<const sockaddr *>(&addr.V4().Handle());
|
||||||
|
addr_len = sizeof(addr.V4().Handle());
|
||||||
|
} else {
|
||||||
|
addr_handle = reinterpret_cast<const sockaddr *>(&addr.V6().Handle());
|
||||||
|
addr_len = sizeof(addr.V6().Handle());
|
||||||
|
}
|
||||||
|
auto socket = TCPSocket::Create(addr.Domain());
|
||||||
|
CHECK_EQ(static_cast<std::int32_t>(socket.Domain()), static_cast<std::int32_t>(addr.Domain()));
|
||||||
|
auto rc = connect(socket.Handle(), addr_handle, addr_len);
|
||||||
|
if (rc != 0) {
|
||||||
|
return std::error_code{errno, std::system_category()};
|
||||||
|
}
|
||||||
|
*out = std::move(socket);
|
||||||
|
return std::make_error_code(std::errc{});
|
||||||
|
}
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
@ -121,6 +121,7 @@ if __name__ == "__main__":
|
|||||||
"python-package/xgboost/sklearn.py",
|
"python-package/xgboost/sklearn.py",
|
||||||
"python-package/xgboost/spark",
|
"python-package/xgboost/spark",
|
||||||
"python-package/xgboost/federated.py",
|
"python-package/xgboost/federated.py",
|
||||||
|
"python-package/xgboost/testing.py",
|
||||||
# tests
|
# tests
|
||||||
"tests/python/test_config.py",
|
"tests/python/test_config.py",
|
||||||
"tests/python/test_spark/",
|
"tests/python/test_spark/",
|
||||||
|
|||||||
77
tests/cpp/collective/test_socket.cc
Normal file
77
tests/cpp/collective/test_socket.cc
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright (c) 2022 by XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/collective/socket.h>
|
||||||
|
|
||||||
|
#include <cerrno> // EADDRNOTAVAIL
|
||||||
|
#include <fstream> // ifstream
|
||||||
|
#include <system_error> // std::error_code, std::system_category
|
||||||
|
|
||||||
|
#include "../helpers.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
TEST(Socket, Basic) {
|
||||||
|
system::SocketStartup();
|
||||||
|
|
||||||
|
SockAddress addr{SockAddrV6::Loopback()};
|
||||||
|
ASSERT_TRUE(addr.IsV6());
|
||||||
|
addr = SockAddress{SockAddrV4::Loopback()};
|
||||||
|
ASSERT_TRUE(addr.IsV4());
|
||||||
|
|
||||||
|
std::string msg{"Skipping IPv6 test"};
|
||||||
|
|
||||||
|
auto run_test = [msg](SockDomain domain) {
|
||||||
|
auto server = TCPSocket::Create(domain);
|
||||||
|
ASSERT_EQ(server.Domain(), domain);
|
||||||
|
auto port = server.BindHost();
|
||||||
|
server.Listen();
|
||||||
|
|
||||||
|
TCPSocket client;
|
||||||
|
if (domain == SockDomain::kV4) {
|
||||||
|
auto const& addr = SockAddrV4::Loopback().Addr();
|
||||||
|
ASSERT_EQ(Connect(MakeSockAddress(StringView{addr}, port), &client), std::errc{});
|
||||||
|
} else {
|
||||||
|
auto const& addr = SockAddrV6::Loopback().Addr();
|
||||||
|
auto rc = Connect(MakeSockAddress(StringView{addr}, port), &client);
|
||||||
|
// some environment (docker) has restricted network configuration.
|
||||||
|
if (rc == std::error_code{EADDRNOTAVAIL, std::system_category()}) {
|
||||||
|
GTEST_SKIP_(msg.c_str());
|
||||||
|
}
|
||||||
|
ASSERT_EQ(rc, std::errc{});
|
||||||
|
}
|
||||||
|
ASSERT_EQ(client.Domain(), domain);
|
||||||
|
|
||||||
|
auto accepted = server.Accept();
|
||||||
|
StringView msg{"Hello world."};
|
||||||
|
accepted.Send(msg);
|
||||||
|
|
||||||
|
std::string str;
|
||||||
|
client.Recv(&str);
|
||||||
|
ASSERT_EQ(StringView{str}, msg);
|
||||||
|
};
|
||||||
|
|
||||||
|
run_test(SockDomain::kV4);
|
||||||
|
|
||||||
|
std::string path{"/sys/module/ipv6/parameters/disable"};
|
||||||
|
if (FileExists(path)) {
|
||||||
|
std::ifstream fin(path);
|
||||||
|
if (!fin) {
|
||||||
|
GTEST_SKIP_(msg.c_str());
|
||||||
|
}
|
||||||
|
std::string s_value;
|
||||||
|
fin >> s_value;
|
||||||
|
auto value = std::stoi(s_value);
|
||||||
|
if (value != 0) {
|
||||||
|
GTEST_SKIP_(msg.c_str());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
GTEST_SKIP_(msg.c_str());
|
||||||
|
}
|
||||||
|
run_test(SockDomain::kV6);
|
||||||
|
|
||||||
|
system::SocketFinalize();
|
||||||
|
}
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
@ -1,7 +1,6 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright (c) 2022 by XGBoost Contributors
|
* Copyright (c) 2022 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef XGBOOST_TESTS_CPP_FILESYSTEM_H
|
#ifndef XGBOOST_TESTS_CPP_FILESYSTEM_H
|
||||||
#define XGBOOST_TESTS_CPP_FILESYSTEM_H
|
#define XGBOOST_TESTS_CPP_FILESYSTEM_H
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,12 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2021-2022, XGBoost contributors.
|
* Copyright 2021-2022, XGBoost contributors.
|
||||||
*/
|
*/
|
||||||
|
#ifndef XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
|
||||||
|
#define XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
|
||||||
#include <xgboost/tree_model.h>
|
#include <xgboost/tree_model.h>
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../../../src/tree/hist/expand_entry.h"
|
#include "../../../src/tree/hist/expand_entry.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -19,3 +23,4 @@ inline void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntr
|
|||||||
}
|
}
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
#endif // XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
|
||||||
|
|||||||
@ -1,34 +1,37 @@
|
|||||||
from xgboost import RabitTracker
|
import re
|
||||||
import xgboost as xgb
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import testing as tm
|
import testing as tm
|
||||||
import numpy as np
|
|
||||||
import sys
|
import xgboost as xgb
|
||||||
import re
|
from xgboost import RabitTracker, testing
|
||||||
|
|
||||||
if sys.platform.startswith("win"):
|
if sys.platform.startswith("win"):
|
||||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||||
|
|
||||||
|
|
||||||
def test_rabit_tracker():
|
def test_rabit_tracker():
|
||||||
tracker = RabitTracker(host_ip='127.0.0.1', n_workers=1)
|
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
|
||||||
tracker.start(1)
|
tracker.start(1)
|
||||||
worker_env = tracker.worker_envs()
|
worker_env = tracker.worker_envs()
|
||||||
rabit_env = []
|
rabit_env = []
|
||||||
for k, v in worker_env.items():
|
for k, v in worker_env.items():
|
||||||
rabit_env.append(f"{k}={v}".encode())
|
rabit_env.append(f"{k}={v}".encode())
|
||||||
with xgb.rabit.RabitContext(rabit_env):
|
with xgb.rabit.RabitContext(rabit_env):
|
||||||
ret = xgb.rabit.broadcast('test1234', 0)
|
ret = xgb.rabit.broadcast("test1234", 0)
|
||||||
assert str(ret) == 'test1234'
|
assert str(ret) == "test1234"
|
||||||
|
|
||||||
|
|
||||||
def run_rabit_ops(client, n_workers):
|
def run_rabit_ops(client, n_workers):
|
||||||
from test_with_dask import _get_client_workers
|
from test_with_dask import _get_client_workers
|
||||||
from xgboost.dask import RabitContext, _get_rabit_args
|
from xgboost.dask import RabitContext, _get_dask_config, _get_rabit_args
|
||||||
|
|
||||||
from xgboost import rabit
|
from xgboost import rabit
|
||||||
|
|
||||||
workers = _get_client_workers(client)
|
workers = _get_client_workers(client)
|
||||||
rabit_args = client.sync(_get_rabit_args, len(workers), None, client)
|
rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
|
||||||
assert not rabit.is_distributed()
|
assert not rabit.is_distributed()
|
||||||
n_workers_from_dask = len(workers)
|
n_workers_from_dask = len(workers)
|
||||||
assert n_workers == n_workers_from_dask
|
assert n_workers == n_workers_from_dask
|
||||||
@ -55,12 +58,26 @@ def run_rabit_ops(client, n_workers):
|
|||||||
@pytest.mark.skipif(**tm.no_dask())
|
@pytest.mark.skipif(**tm.no_dask())
|
||||||
def test_rabit_ops():
|
def test_rabit_ops():
|
||||||
from distributed import Client, LocalCluster
|
from distributed import Client, LocalCluster
|
||||||
|
|
||||||
n_workers = 3
|
n_workers = 3
|
||||||
with LocalCluster(n_workers=n_workers) as cluster:
|
with LocalCluster(n_workers=n_workers) as cluster:
|
||||||
with Client(cluster) as client:
|
with Client(cluster) as client:
|
||||||
run_rabit_ops(client, n_workers)
|
run_rabit_ops(client, n_workers)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**testing.skip_ipv6())
|
||||||
|
@pytest.mark.skipif(**tm.no_dask())
|
||||||
|
def test_rabit_ops_ipv6():
|
||||||
|
import dask
|
||||||
|
from distributed import Client, LocalCluster
|
||||||
|
|
||||||
|
n_workers = 3
|
||||||
|
with dask.config.set({"xgboost.scheduler_address": "[::1]"}):
|
||||||
|
with LocalCluster(n_workers=n_workers, host="[::1]") as cluster:
|
||||||
|
with Client(cluster) as client:
|
||||||
|
run_rabit_ops(client, n_workers)
|
||||||
|
|
||||||
|
|
||||||
def test_rank_assignment() -> None:
|
def test_rank_assignment() -> None:
|
||||||
from distributed import Client, LocalCluster
|
from distributed import Client, LocalCluster
|
||||||
from test_with_dask import _get_client_workers
|
from test_with_dask import _get_client_workers
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user