Initial support for IPv6 (#8225)

- Merge rabit socket into XGBoost.
- Dask interface support.
- Add test to the socket.
This commit is contained in:
Jiaming Yuan 2022-09-21 18:06:50 +08:00 committed by GitHub
parent 7d43e74e71
commit b791446623
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 924 additions and 595 deletions

View File

@ -91,6 +91,9 @@
#include "../src/common/timer.cc"
#include "../src/common/version.cc"
// collective
#include "../src/collective/socket.cc"
// c_api
#include "../src/c_api/c_api.cc"
#include "../src/c_api/c_api_error.cc"

View File

@ -204,7 +204,7 @@ latex_documents = [
]
intersphinx_mapping = {
"python": ("https://docs.python.org/3.6", None),
"python": ("https://docs.python.org/3.8", None),
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
"pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),

View File

@ -474,7 +474,6 @@ interface, including callback functions, custom evaluation metric and objective:
callbacks=[early_stop],
)
.. _tracker-ip:
***************
@ -504,6 +503,35 @@ dask config is used:
reg = dxgb.DaskXGBRegressor()
************
IPv6 Support
************
.. versionadded:: 2.0.0
XGBoost has initial IPv6 support for the dask interface on Linux. Due to most of the
cluster support for IPv6 is partial (dual stack instead of IPv6 only), we require
additional user configuration similar to :ref:`tracker-ip` to help XGBoost obtain the
correct address information:
.. code-block:: python
import dask
from distributed import Client
from xgboost import dask as dxgb
# let xgboost know the scheduler address, use the same bracket format as dask.
with dask.config.set({"xgboost.scheduler_address": "[fd20:b6f:f759:9800::]"}):
with Client("[fd20:b6f:f759:9800::]") as client:
reg = dxgb.DaskXGBRegressor(tree_method="hist")
When GPU is used, XGBoost employs `NCCL <https://developer.nvidia.com/nccl>`_ as the
underlying communication framework, which may require some additional configuration via
environment variable depending on the setting of the cluster. Please note that IPv6
support is Unix only.
*****************************************************************************
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
*****************************************************************************

View File

@ -0,0 +1,536 @@
/*!
* Copyright (c) 2022 by XGBoost Contributors
*/
#pragma once
#if !defined(NOMINMAX) && defined(_WIN32)
#define NOMINMAX
#endif // !defined(NOMINMAX)
#include <cerrno> // errno, EINTR, EBADF
#include <climits> // HOST_NAME_MAX
#include <cstddef> // std::size_t
#include <cstdint> // std::int32_t, std::uint16_t
#include <cstring> // memset
#include <limits> // std::numeric_limits
#include <string> // std::string
#include <system_error> // std::error_code, std::system_category
#include <utility> // std::swap
#if !defined(xgboost_IS_MINGW)
#define xgboost_IS_MINGW() defined(__MINGW32__)
#endif // xgboost_IS_MINGW
#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
using in_port_t = std::uint16_t;
#ifdef _MSC_VER
#pragma comment(lib, "Ws2_32.lib")
#endif // _MSC_VER
#if !xgboost_IS_MINGW()
using ssize_t = int;
#endif // !xgboost_IS_MINGW()
#else // UNIX
#include <arpa/inet.h> // inet_ntop
#include <fcntl.h> // fcntl, F_GETFL, O_NONBLOCK
#include <netinet/in.h> // sockaddr_in6, sockaddr_in, in_port_t, INET6_ADDRSTRLEN, INET_ADDRSTRLEN
#include <netinet/in.h> // IPPROTO_TCP
#include <netinet/tcp.h> // TCP_NODELAY
#include <sys/socket.h> // socket, SOL_SOCKET, SO_ERROR, MSG_WAITALL, recv, send, AF_INET6, AF_INET
#include <unistd.h> // close
#if defined(__sun) || defined(sun)
#include <sys/sockio.h>
#endif // defined(__sun) || defined(sun)
#endif // defined(_WIN32)
#include "xgboost/base.h" // XGBOOST_EXPECT
#include "xgboost/logging.h" // LOG
#include "xgboost/string_view.h" // StringView
#if !defined(HOST_NAME_MAX)
#define HOST_NAME_MAX 256 // macos
#endif
namespace xgboost {
#if xgboost_IS_MINGW()
// see the dummy implementation of `poll` in rabit for more info.
inline void MingWError() { LOG(FATAL) << "Distributed training on mingw is not supported."; }
#endif // xgboost_IS_MINGW()
namespace system {
inline std::int32_t LastError() {
#if defined(_WIN32)
return WSAGetLastError();
#else
int errsv = errno;
return errsv;
#endif
}
#if defined(__GLIBC__)
inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
std::int32_t line = __builtin_LINE(),
char const *file = __builtin_FILE()) {
auto err = std::error_code{errsv, std::system_category()};
LOG(FATAL) << "\n"
<< file << "(" << line << "): Failed to call `" << fn_name << "`: " << err.message()
<< std::endl;
}
#else
inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError()) {
auto err = std::error_code{errsv, std::system_category()};
LOG(FATAL) << "Failed to call `" << fn_name << "`: " << err.message() << std::endl;
}
#endif // defined(__GLIBC__)
#if defined(_WIN32)
using SocketT = SOCKET;
#else
using SocketT = int;
#endif // defined(_WIN32)
#if !defined(xgboost_CHECK_SYS_CALL)
#define xgboost_CHECK_SYS_CALL(exp, expected) \
do { \
if (XGBOOST_EXPECT((exp) != (expected), false)) { \
::xgboost::system::ThrowAtError(#exp); \
} \
} while (false)
#endif // !defined(xgboost_CHECK_SYS_CALL)
inline std::int32_t CloseSocket(SocketT fd) {
#if defined(_WIN32)
return closesocket(fd);
#else
return close(fd);
#endif
}
inline bool LastErrorWouldBlock() {
int errsv = LastError();
#ifdef _WIN32
return errsv == WSAEWOULDBLOCK;
#else
return errsv == EAGAIN || errsv == EWOULDBLOCK;
#endif // _WIN32
}
inline void SocketStartup() {
#if defined(_WIN32)
WSADATA wsa_data;
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
ThrowAtError("WSAStartup");
}
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
WSACleanup();
LOG(FATAL) << "Could not find a usable version of Winsock.dll";
}
#endif // defined(_WIN32)
}
inline void SocketFinalize() {
#if defined(_WIN32)
WSACleanup();
#endif // defined(_WIN32)
}
#if defined(_WIN32) && xgboost_IS_MINGW()
// dummy definition for old mysys32.
inline const char *inet_ntop(int, const void *, char *, socklen_t) { // NOLINT
MingWError();
return nullptr;
}
#else
using ::inet_ntop;
#endif
} // namespace system
namespace collective {
class SockAddress;
enum class SockDomain : std::int32_t { kV4 = AF_INET, kV6 = AF_INET6 };
/**
* \brief Parse host address and return a SockAddress instance. Supports IPv4 and IPv6
* host.
*/
SockAddress MakeSockAddress(StringView host, in_port_t port);
class SockAddrV6 {
sockaddr_in6 addr_;
public:
explicit SockAddrV6(sockaddr_in6 addr) : addr_{addr} {}
SockAddrV6() { std::memset(&addr_, '\0', sizeof(addr_)); }
static SockAddrV6 Loopback();
static SockAddrV6 InaddrAny();
in_port_t Port() const { return ntohs(addr_.sin6_port); }
std::string Addr() const {
char buf[INET6_ADDRSTRLEN];
auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV6), &addr_.sin6_addr,
buf, INET6_ADDRSTRLEN);
if (s == nullptr) {
system::ThrowAtError("inet_ntop");
}
return {buf};
}
sockaddr_in6 const &Handle() const { return addr_; }
};
class SockAddrV4 {
private:
sockaddr_in addr_;
public:
explicit SockAddrV4(sockaddr_in addr) : addr_{addr} {}
SockAddrV4() { std::memset(&addr_, '\0', sizeof(addr_)); }
static SockAddrV4 Loopback();
static SockAddrV4 InaddrAny();
in_port_t Port() const { return ntohs(addr_.sin_port); }
std::string Addr() const {
char buf[INET_ADDRSTRLEN];
auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV4), &addr_.sin_addr,
buf, INET_ADDRSTRLEN);
if (s == nullptr) {
system::ThrowAtError("inet_ntop");
}
return {buf};
}
sockaddr_in const &Handle() const { return addr_; }
};
/**
* \brief Address for TCP socket, can be either IPv4 or IPv6.
*/
class SockAddress {
private:
SockAddrV6 v6_;
SockAddrV4 v4_;
SockDomain domain_{SockDomain::kV4};
public:
SockAddress() = default;
explicit SockAddress(SockAddrV6 const &addr) : v6_{addr}, domain_{SockDomain::kV6} {}
explicit SockAddress(SockAddrV4 const &addr) : v4_{addr} {}
auto Domain() const { return domain_; }
bool IsV4() const { return Domain() == SockDomain::kV4; }
bool IsV6() const { return !IsV4(); }
auto const &V4() const { return v4_; }
auto const &V6() const { return v6_; }
};
/**
* \brief TCP socket for simple communication.
*/
class TCPSocket {
public:
using HandleT = system::SocketT;
private:
HandleT handle_{InvalidSocket()};
// There's reliable no way to extract domain from a socket without first binding that
// socket on macos.
#if defined(__APPLE__)
SockDomain domain_{SockDomain::kV4};
#endif
constexpr static HandleT InvalidSocket() { return -1; }
explicit TCPSocket(HandleT newfd) : handle_{newfd} {}
public:
TCPSocket() = default;
/**
* \brief Return the socket domain.
*/
auto Domain() const -> SockDomain {
auto ret_iafamily = [](std::int32_t domain) {
switch (domain) {
case AF_INET:
return SockDomain::kV4;
case AF_INET6:
return SockDomain::kV6;
default: {
LOG(FATAL) << "Unknown IA family.";
}
}
return SockDomain::kV4;
};
#if defined(_WIN32)
WSAPROTOCOL_INFOA info;
socklen_t len = sizeof(info);
xgboost_CHECK_SYS_CALL(
getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO, reinterpret_cast<char *>(&info), &len),
0);
return ret_iafamily(info.iAddressFamily);
#elif defined(__APPLE__)
return domain_;
#elif defined(__unix__)
std::int32_t domain;
socklen_t len = sizeof(domain);
xgboost_CHECK_SYS_CALL(
getsockopt(handle_, SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len), 0);
return ret_iafamily(domain);
#else
LOG(FATAL) << "Unknown platform.";
return ret_iafamily(AF_INET);
#endif // platforms
}
bool IsClosed() const { return handle_ == InvalidSocket(); }
/** \brief get last error code if any */
std::int32_t GetSockError() const {
std::int32_t error = 0;
socklen_t len = sizeof(error);
xgboost_CHECK_SYS_CALL(
getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&error), &len), 0);
return error;
}
/** \brief check if anything bad happens */
bool BadSocket() const {
if (IsClosed()) return true;
std::int32_t err = GetSockError();
if (err == EBADF || err == EINTR) return true;
return false;
}
void SetNonBlock() {
bool non_block{true};
#if defined(_WIN32)
u_long mode = non_block ? 1 : 0;
xgboost_CHECK_SYS_CALL(ioctlsocket(handle_, FIONBIO, &mode), NO_ERROR);
#else
std::int32_t flag = fcntl(handle_, F_GETFL, 0);
if (flag == -1) {
system::ThrowAtError("fcntl");
}
if (non_block) {
flag |= O_NONBLOCK;
} else {
flag &= ~O_NONBLOCK;
}
if (fcntl(handle_, F_SETFL, flag) == -1) {
system::ThrowAtError("fcntl");
}
#endif // _WIN32
}
void SetKeepAlive() {
std::int32_t keepalive = 1;
xgboost_CHECK_SYS_CALL(setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE,
reinterpret_cast<char *>(&keepalive), sizeof(keepalive)),
0);
}
void SetNoDelay() {
std::int32_t tcp_no_delay = 1;
xgboost_CHECK_SYS_CALL(
setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
sizeof(tcp_no_delay)),
0);
}
/**
* \brief Accept new connection, returns a new TCP socket for the new connection.
*/
TCPSocket Accept() {
HandleT newfd = accept(handle_, nullptr, nullptr);
if (newfd == InvalidSocket()) {
system::ThrowAtError("accept");
}
TCPSocket newsock{newfd};
return newsock;
}
~TCPSocket() {
if (!IsClosed()) {
Close();
}
}
TCPSocket(TCPSocket const &that) = delete;
TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
TCPSocket &operator=(TCPSocket const &that) = delete;
TCPSocket &operator=(TCPSocket &&that) {
std::swap(this->handle_, that.handle_);
return *this;
}
/**
* \brief Return the native socket file descriptor.
*/
HandleT const &Handle() const { return handle_; }
/**
* \brief Listen to incoming requests. Should be called after bind.
*/
void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); }
/**
* \brief Bind socket to INADDR_ANY, return the port selected by the OS.
*/
in_port_t BindHost() {
if (Domain() == SockDomain::kV6) {
auto addr = SockAddrV6::InaddrAny();
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
xgboost_CHECK_SYS_CALL(
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
sockaddr_in6 res_addr;
socklen_t addrlen = sizeof(res_addr);
xgboost_CHECK_SYS_CALL(
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
return ntohs(res_addr.sin6_port);
} else {
auto addr = SockAddrV4::InaddrAny();
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
xgboost_CHECK_SYS_CALL(
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
sockaddr_in res_addr;
socklen_t addrlen = sizeof(res_addr);
xgboost_CHECK_SYS_CALL(
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
return ntohs(res_addr.sin_port);
}
}
/**
* \brief Send data, without error then all data should be sent.
*/
auto SendAll(void const *buf, std::size_t len) {
char const *_buf = reinterpret_cast<const char *>(buf);
std::size_t ndone = 0;
while (ndone < len) {
ssize_t ret = send(handle_, _buf, len - ndone, 0);
if (ret == -1) {
if (system::LastErrorWouldBlock()) {
return ndone;
}
system::ThrowAtError("send");
}
_buf += ret;
ndone += ret;
}
return ndone;
}
/**
* \brief Receive data, without error then all data should be received.
*/
auto RecvAll(void *buf, std::size_t len) {
char *_buf = reinterpret_cast<char *>(buf);
std::size_t ndone = 0;
while (ndone < len) {
ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
if (ret == -1) {
if (system::LastErrorWouldBlock()) {
return ndone;
}
system::ThrowAtError("recv");
}
if (ret == 0) {
return ndone;
}
_buf += ret;
ndone += ret;
}
return ndone;
}
/**
* \brief Send data using the socket
* \param buf the pointer to the buffer
* \param len the size of the buffer
* \param flags extra flags
* \return size of data actually sent return -1 if error occurs
*/
auto Send(const void *buf_, std::size_t len, std::int32_t flags = 0) {
const char *buf = reinterpret_cast<const char *>(buf_);
return send(handle_, buf, len, flags);
}
/**
* \brief receive data using the socket
* \param buf the pointer to the buffer
* \param len the size of the buffer
* \param flags extra flags
* \return size of data actually received return -1 if error occurs
*/
auto Recv(void *buf, std::size_t len, std::int32_t flags = 0) {
char *_buf = reinterpret_cast<char *>(buf);
return recv(handle_, _buf, len, flags);
}
/**
* \brief Send string, format is matched with the Python socket wrapper in RABIT.
*/
std::size_t Send(StringView str);
/**
* \brief Receive string, format is matched with the Python socket wrapper in RABIT.
*/
std::size_t Recv(std::string *p_str);
/**
* \brief Close the socket, called automatically in destructor if the socket is not closed.
*/
void Close() {
if (InvalidSocket() != handle_) {
xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0);
handle_ = InvalidSocket();
}
}
/**
* \brief Create a TCP socket on specified domain.
*/
static TCPSocket Create(SockDomain domain) {
#if xgboost_IS_MINGW()
MingWError();
return {};
#else
auto fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
if (fd == InvalidSocket()) {
system::ThrowAtError("socket");
}
TCPSocket socket{fd};
#if defined(__APPLE__)
socket.domain_ = domain;
#endif // defined(__APPLE__)
return socket;
#endif // xgboost_IS_MINGW()
}
};
/**
* \brief Connect to remote address, returns the error code if failed (no exception is
* raised so that we can retry).
*/
std::error_code Connect(SockAddress const &addr, TCPSocket *out);
/**
* \brief Get the local host name.
*/
inline std::string GetHostName() {
char buf[HOST_NAME_MAX];
xgboost_CHECK_SYS_CALL(gethostname(&buf[0], HOST_NAME_MAX), 0);
return buf;
}
} // namespace collective
} // namespace xgboost
#undef xgboost_CHECK_SYS_CALL
#undef xgboost_IS_MINGW

View File

@ -52,6 +52,7 @@ from typing import (
Sequence,
Set,
Tuple,
TypedDict,
TypeVar,
Union,
)
@ -102,19 +103,13 @@ else:
_DaskCollection = Union["da.Array", "dd.DataFrame", "dd.Series"]
_DataT = Union["da.Array", "dd.DataFrame"] # do not use series as predictor
try:
from mypy_extensions import TypedDict
TrainReturnT = TypedDict(
"TrainReturnT",
{
"booster": Booster,
"history": Dict,
},
)
except ImportError:
TrainReturnT = Dict[str, Any] # type:ignore
TrainReturnT = TypedDict(
"TrainReturnT",
{
"booster": Booster,
"history": Dict,
},
)
__all__ = [
"RabitContext",
@ -832,11 +827,15 @@ async def _get_rabit_args(
if k not in valid_config:
raise ValueError(f"Unknown configuration: {k}")
host_ip = dconfig.get("scheduler_address", None)
if host_ip is not None and host_ip.startswith("[") and host_ip.endswith("]"):
# convert dask bracket format to proper IPv6 address.
host_ip = host_ip[1:-1]
if host_ip is not None:
try:
host_ip, port = distributed.comm.get_address_host_port(host_ip)
except ValueError:
pass
if host_ip is not None:
user_addr = (host_ip, port)
else:

View File

@ -0,0 +1,41 @@
"""Utilities for defining Python tests."""
import socket
from platform import system
from typing import TypedDict
PytestSkip = TypedDict("PytestSkip", {"condition": bool, "reason": str})
def has_ipv6() -> bool:
"""Check whether IPv6 is enabled on this host."""
# connection error in macos, still need some fixes.
if system() not in ("Linux", "Windows"):
return False
if socket.has_ipv6:
try:
with socket.socket(
socket.AF_INET6, socket.SOCK_STREAM
) as server, socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as client:
server.bind(("::1", 0))
port = server.getsockname()[1]
server.listen()
client.connect(("::1", port))
conn, _ = server.accept()
client.sendall("abc".encode())
msg = conn.recv(3).decode()
# if the code can be executed to this point, the message should be
# correct.
assert msg == "abc"
return True
except OSError:
pass
return False
def skip_ipv6() -> PytestSkip:
"""PyTest skip mark for IPv6."""
return {"condition": not has_ipv6(), "reason": "IPv6 is required to be enabled."}

View File

@ -112,7 +112,7 @@ class WorkerEntry:
"""Assign the rank for current entry."""
self.rank = rank
nnset = set(tree_map[rank])
rprev, rnext = ring_map[rank]
rprev, next_rank = ring_map[rank]
self.sock.sendint(rank)
# send parent rank
self.sock.sendint(parent_map[rank])
@ -129,9 +129,9 @@ class WorkerEntry:
else:
self.sock.sendint(-1)
# send next link
if rnext not in (-1, rank):
nnset.add(rnext)
self.sock.sendint(rnext)
if next_rank not in (-1, rank):
nnset.add(next_rank)
self.sock.sendint(next_rank)
else:
self.sock.sendint(-1)
@ -157,6 +157,7 @@ class WorkerEntry:
self.sock.sendstr(wait_conn[r].host)
port = wait_conn[r].port
assert port is not None
# send port of this node to other workers so that they can call connect
self.sock.sendint(port)
self.sock.sendint(r)
nerr = self.sock.recvint()

View File

@ -1,65 +1,48 @@
/*!
* Copyright (c) 2014-2019 by Contributors
* Copyright (c) 2014-2022 by XGBoost Contributors
* \file socket.h
* \brief this file aims to provide a wrapper of sockets
* \author Tianqi Chen
*/
#ifndef RABIT_INTERNAL_SOCKET_H_
#define RABIT_INTERNAL_SOCKET_H_
#include "xgboost/collective/socket.h"
#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
#ifdef _MSC_VER
#pragma comment(lib, "Ws2_32.lib")
#endif // _MSC_VER
#else
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <cerrno>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <unistd.h>
#if defined(__sun) || defined(sun)
#include <sys/sockio.h>
#endif // defined(__sun) || defined(sun)
#include <cerrno>
#endif // defined(_WIN32)
#include <string>
#include <cstring>
#include <vector>
#include <chrono>
#include <cstring>
#include <string>
#include <unordered_map>
#include <vector>
#include "utils.h"
#if defined(_WIN32) && !defined(__MINGW32__)
typedef int ssize_t;
#endif // defined(_WIN32) || defined(__MINGW32__)
#if defined(_WIN32)
using sock_size_t = int;
#else
#if !defined(_WIN32)
#include <sys/poll.h>
using SOCKET = int;
using sock_size_t = size_t; // NOLINT
#endif // defined(_WIN32)
#endif // !defined(_WIN32)
#define IS_MINGW() defined(__MINGW32__)
#if IS_MINGW()
inline void MingWError() {
throw dmlc::Error("Distributed training on mingw is not supported.");
}
#endif // IS_MINGW()
#if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
/*
* On later mingw versions poll should be supported (with bugs). See:
@ -88,23 +71,17 @@ typedef struct pollfd {
// POLLWRNORM
#define POLLOUT 0x0010
inline const char *inet_ntop(int, const void *, char *, size_t) {
MingWError();
return nullptr;
}
#endif // IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
namespace rabit {
namespace utils {
static constexpr int kInvalidSocket = -1;
template <typename PollFD>
int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) {
#if defined(_WIN32)
#if IS_MINGW()
MingWError();
xgboost::MingWError();
return -1;
#else
return WSAPoll(pfd, nfds, std::chrono::milliseconds(timeout).count());
@ -115,458 +92,6 @@ int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) {
#endif // IS_MINGW()
}
/*! \brief data structure for network address */
struct SockAddr {
sockaddr_in addr;
// constructor
SockAddr() = default;
SockAddr(const char *url, int port) {
this->Set(url, port);
}
inline static std::string GetHostName() {
std::string buf; buf.resize(256);
#if !IS_MINGW()
utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
#endif // IS_MINGW()
return std::string(buf.c_str());
}
/*!
* \brief set the address
* \param url the url of the address
* \param port the port of address
*/
inline void Set(const char *host, int port) {
#if !IS_MINGW()
addrinfo hints;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_INET;
hints.ai_protocol = SOCK_STREAM;
addrinfo *res = nullptr;
int sig = getaddrinfo(host, nullptr, &hints, &res);
Check(sig == 0 && res != nullptr, "cannot obtain address of %s", host);
Check(res->ai_family == AF_INET, "Does not support IPv6");
memcpy(&addr, res->ai_addr, res->ai_addrlen);
addr.sin_port = htons(port);
freeaddrinfo(res);
#endif // !IS_MINGW()
}
/*! \brief return port of the address*/
inline int Port() const {
return ntohs(addr.sin_port);
}
/*! \return a string representation of the address */
inline std::string AddrStr() const {
std::string buf; buf.resize(256);
#ifdef _WIN32
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
&buf[0], buf.length());
#else
const char *s = inet_ntop(AF_INET, &addr.sin_addr,
&buf[0], buf.length());
#endif // _WIN32
Assert(s != nullptr, "cannot decode address");
return std::string(s);
}
};
/*!
* \brief base class containing common operations of TCP and UDP sockets
*/
class Socket {
public:
/*! \brief the file descriptor of socket */
SOCKET sockfd;
// default conversion to int
operator SOCKET() const { // NOLINT
return sockfd;
}
/*!
* \return last error of socket operation
*/
inline static int GetLastError() {
#ifdef _WIN32
#if IS_MINGW()
MingWError();
return -1;
#else
return WSAGetLastError();
#endif // IS_MINGW()
#else
return errno;
#endif // _WIN32
}
/*! \return whether last error was would block */
inline static bool LastErrorWouldBlock() {
int errsv = GetLastError();
#ifdef _WIN32
return errsv == WSAEWOULDBLOCK;
#else
return errsv == EAGAIN || errsv == EWOULDBLOCK;
#endif // _WIN32
}
/*!
* \brief start up the socket module
* call this before using the sockets
*/
inline static void Startup() {
#ifdef _WIN32
#if !IS_MINGW()
WSADATA wsa_data;
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
Socket::Error("Startup");
}
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
WSACleanup();
utils::Error("Could not find a usable version of Winsock.dll\n");
}
#endif // !IS_MINGW()
#endif // _WIN32
}
/*!
* \brief shutdown the socket module after use, all sockets need to be closed
*/
inline static void Finalize() {
#ifdef _WIN32
#if !IS_MINGW()
WSACleanup();
#endif // !IS_MINGW()
#endif // _WIN32
}
/*!
* \brief set this socket to use non-blocking mode
* \param non_block whether set it to be non-block, if it is false
* it will set it back to block mode
*/
inline void SetNonBlock(bool non_block) {
#ifdef _WIN32
#if !IS_MINGW()
u_long mode = non_block ? 1 : 0;
if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
Socket::Error("SetNonBlock");
}
#endif // !IS_MINGW()
#else
int flag = fcntl(sockfd, F_GETFL, 0);
if (flag == -1) {
Socket::Error("SetNonBlock-1");
}
if (non_block) {
flag |= O_NONBLOCK;
} else {
flag &= ~O_NONBLOCK;
}
if (fcntl(sockfd, F_SETFL, flag) == -1) {
Socket::Error("SetNonBlock-2");
}
#endif // _WIN32
}
/*!
* \brief bind the socket to an address
* \param addr
*/
inline void Bind(const SockAddr &addr) {
#if !IS_MINGW()
if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
sizeof(addr.addr)) == -1) {
Socket::Error("Bind");
}
#endif // !IS_MINGW()
}
/*!
* \brief try bind the socket to host, from start_port to end_port
* \param start_port starting port number to try
* \param end_port ending port number to try
* \return the port successfully bind to, return -1 if failed to bind any port
*/
inline int TryBindHost(int start_port, int end_port) {
// TODO(tqchen) add prefix check
#if !IS_MINGW()
for (int port = start_port; port < end_port; ++port) {
SockAddr addr("0.0.0.0", port);
if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
sizeof(addr.addr)) == 0) {
return port;
}
#if defined(_WIN32)
if (WSAGetLastError() != WSAEADDRINUSE) {
Socket::Error("TryBindHost");
}
#else
if (errno != EADDRINUSE) {
Socket::Error("TryBindHost");
}
#endif // defined(_WIN32)
}
#endif // !IS_MINGW()
return -1;
}
/*! \brief get last error code if any */
inline int GetSockError() const {
int error = 0;
socklen_t len = sizeof(error);
#if !IS_MINGW()
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
reinterpret_cast<char *>(&error), &len) != 0) {
Error("GetSockError");
}
#else
// undefined reference to `_imp__getsockopt@20'
MingWError();
#endif // !IS_MINGW()
return error;
}
/*! \brief check if anything bad happens */
inline bool BadSocket() const {
if (IsClosed()) return true;
int err = GetSockError();
if (err == EBADF || err == EINTR) return true;
return false;
}
/*! \brief check if socket is already closed */
inline bool IsClosed() const {
return sockfd == kInvalidSocket;
}
/*! \brief close the socket */
inline void Close() {
if (sockfd != kInvalidSocket) {
#ifdef _WIN32
#if !IS_MINGW()
closesocket(sockfd);
#endif // !IS_MINGW()
#else
close(sockfd);
#endif
sockfd = kInvalidSocket;
} else {
Error("Socket::Close double close the socket or close without create");
}
}
// report an socket error
inline static void Error(const char *msg) {
int errsv = GetLastError();
#ifdef _WIN32
utils::Error("Socket %s Error:WSAError-code=%d", msg, errsv);
#else
utils::Error("Socket %s Error:%s", msg, strerror(errsv));
#endif
}
protected:
explicit Socket(SOCKET sockfd) : sockfd(sockfd) {
}
};
/*!
* \brief a wrapper of TCP socket that hopefully be cross platform
*/
class TCPSocket : public Socket{
public:
// constructor
TCPSocket() : Socket(kInvalidSocket) {
}
explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) {
}
/*!
* \brief enable/disable TCP keepalive
* \param keepalive whether to set the keep alive option on
*/
void SetKeepAlive(bool keepalive) {
#if !IS_MINGW()
int opt = static_cast<int>(keepalive);
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
Socket::Error("SetKeepAlive");
}
#endif // !IS_MINGW()
}
inline void SetLinger(int timeout = 0) {
#if !IS_MINGW()
struct linger sl;
sl.l_onoff = 1; /* non-zero value enables linger option in kernel */
sl.l_linger = timeout; /* timeout interval in seconds */
if (setsockopt(sockfd, SOL_SOCKET, SO_LINGER, reinterpret_cast<char*>(&sl), sizeof(sl)) == -1) {
Socket::Error("SO_LINGER");
}
#endif // !IS_MINGW()
}
/*!
* \brief create the socket, call this before using socket
* \param af domain
*/
inline void Create(int af = PF_INET) {
#if !IS_MINGW()
sockfd = socket(af, SOCK_STREAM, 0);
if (sockfd == kInvalidSocket) {
Socket::Error("Create");
}
#endif // !IS_MINGW()
}
/*!
* \brief perform listen of the socket
* \param backlog backlog parameter
*/
inline void Listen(int backlog = 16) {
#if !IS_MINGW()
listen(sockfd, backlog);
#endif // !IS_MINGW()
}
/*! \brief get a new connection */
TCPSocket Accept() {
#if !IS_MINGW()
SOCKET newfd = accept(sockfd, nullptr, nullptr);
if (newfd == kInvalidSocket) {
Socket::Error("Accept");
}
return TCPSocket(newfd);
#else
return TCPSocket();
#endif // !IS_MINGW()
}
/*!
* \brief decide whether the socket is at OOB mark
* \return 1 if at mark, 0 if not, -1 if an error occured
*/
inline int AtMark() const {
#if !IS_MINGW()
#ifdef _WIN32
unsigned long atmark; // NOLINT(*)
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
#else
int atmark;
if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
#endif // _WIN32
return static_cast<int>(atmark);
#else
return -1;
#endif // !IS_MINGW()
}
/*!
* \brief connect to an address
* \param addr the address to connect to
* \return whether connect is successful
*/
inline bool Connect(const SockAddr &addr) {
#if !IS_MINGW()
return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
sizeof(addr.addr)) == 0;
#else
return false;
#endif // !IS_MINGW()
}
/*!
* \brief send data using the socket
* \param buf the pointer to the buffer
* \param len the size of the buffer
* \param flags extra flags
* \return size of data actually sent
* return -1 if error occurs
*/
inline ssize_t Send(const void *buf_, size_t len, int flag = 0) {
const char *buf = reinterpret_cast<const char*>(buf_);
#if !IS_MINGW()
return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
#else
return 0;
#endif // !IS_MINGW()
}
/*!
* \brief receive data using the socket
* \param buf_ the pointer to the buffer
* \param len the size of the buffer
* \param flags extra flags
* \return size of data actually received
* return -1 if error occurs
*/
inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
char *buf = reinterpret_cast<char*>(buf_);
#if !IS_MINGW()
return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
#else
return 0;
#endif // !IS_MINGW()
}
/*!
* \brief peform block write that will attempt to send all data out
* can still return smaller than request when error occurs
* \param buf the pointer to the buffer
* \param len the size of the buffer
* \return size of data actually sent
*/
inline size_t SendAll(const void *buf_, size_t len) {
const char *buf = reinterpret_cast<const char*>(buf_);
size_t ndone = 0;
#if !IS_MINGW()
while (ndone < len) {
ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
if (ret == -1) {
if (LastErrorWouldBlock()) return ndone;
Socket::Error("SendAll");
}
buf += ret;
ndone += ret;
}
#endif // !IS_MINGW()
return ndone;
}
/*!
* \brief peforma block read that will attempt to read all data
* can still return smaller than request when error occurs
* \param buf_ the buffer pointer
* \param len length of data to recv
* \return size of data actually sent
*/
inline size_t RecvAll(void *buf_, size_t len) {
char *buf = reinterpret_cast<char*>(buf_);
size_t ndone = 0;
#if !IS_MINGW()
while (ndone < len) {
ssize_t ret = recv(sockfd, buf,
static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
if (ret == -1) {
if (LastErrorWouldBlock()) return ndone;
Socket::Error("RecvAll");
}
if (ret == 0) return ndone;
buf += ret;
ndone += ret;
}
#endif // !IS_MINGW()
return ndone;
}
/*!
* \brief send a string over network
* \param str the string to be sent
*/
inline void SendStr(const std::string &str) {
int len = static_cast<int>(str.length());
utils::Assert(this->SendAll(&len, sizeof(len)) == sizeof(len),
"error during send SendStr");
if (len != 0) {
utils::Assert(this->SendAll(str.c_str(), str.length()) == str.length(),
"error during send SendStr");
}
}
/*!
* \brief recv a string from network
* \param out_str the string to receive
*/
inline void RecvStr(std::string *out_str) {
int len;
utils::Assert(this->RecvAll(&len, sizeof(len)) == sizeof(len),
"error during send RecvStr");
out_str->resize(len);
if (len != 0) {
utils::Assert(this->RecvAll(&(*out_str)[0], len) == out_str->length(),
"error during send SendStr");
}
}
};
/*! \brief helper data structure to perform poll */
struct PollHelper {
public:
@ -579,6 +104,8 @@ struct PollHelper {
pfd.fd = fd;
pfd.events |= POLLIN;
}
void WatchRead(xgboost::collective::TCPSocket const &socket) { this->WatchRead(socket.Handle()); }
/*!
* \brief add file descriptor to watch for write
* \param fd file descriptor to be watched
@ -588,6 +115,10 @@ struct PollHelper {
pfd.fd = fd;
pfd.events |= POLLOUT;
}
void WatchWrite(xgboost::collective::TCPSocket const &socket) {
this->WatchWrite(socket.Handle());
}
/*!
* \brief add file descriptor to watch for exception
* \param fd file descriptor to be watched
@ -597,6 +128,9 @@ struct PollHelper {
pfd.fd = fd;
pfd.events |= POLLPRI;
}
void WatchException(xgboost::collective::TCPSocket const &socket) {
this->WatchException(socket.Handle());
}
/*!
* \brief Check if the descriptor is ready for read
* \param fd file descriptor to check status
@ -605,6 +139,10 @@ struct PollHelper {
const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
}
bool CheckRead(xgboost::collective::TCPSocket const &socket) const {
return this->CheckRead(socket.Handle());
}
/*!
* \brief Check if the descriptor is ready for write
* \param fd file descriptor to check status
@ -613,7 +151,9 @@ struct PollHelper {
const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
}
bool CheckWrite(xgboost::collective::TCPSocket const &socket) const {
return this->CheckWrite(socket.Handle());
}
/*!
* \brief perform poll on the set defined, read, write, exception
* \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
@ -629,7 +169,7 @@ struct PollHelper {
if (ret == 0) {
LOG(FATAL) << "Poll timeout";
} else if (ret < 0) {
Socket::Error("Poll");
LOG(FATAL) << "Failed to poll.";
} else {
for (auto& pfd : fdset) {
auto revents = pfd.revents & pfd.events;

View File

@ -8,15 +8,17 @@
#define RABIT_INTERNAL_UTILS_H_
#include <rabit/base.h>
#include <cstring>
#include <cstdarg>
#include <cstdio>
#include <string>
#include <cstdlib>
#include <cstring>
#include <stdexcept>
#include <string>
#include <vector>
#include "dmlc/io.h"
#include "xgboost/logging.h"
#include <cstdarg>
#if !defined(__GNUC__) || defined(__FreeBSD__)
#define fopen64 std::fopen

View File

@ -27,8 +27,6 @@ AllreduceBase::AllreduceBase() {
tracker_uri = "NULL";
tracker_port = 9000;
host_uri = "";
slave_port = 9010;
nport_trial = 1000;
rank = 0;
world_size = -1;
connect_retry = 5;
@ -114,16 +112,16 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
this->rank = -1;
//---------------------
// start socket
utils::Socket::Startup();
xgboost::system::SocketStartup();
utils::Assert(all_links.size() == 0, "can only call Init once");
this->host_uri = utils::SockAddr::GetHostName();
this->host_uri = xgboost::collective::GetHostName();
// get information from tracker
return this->ReConnectLinks();
}
bool AllreduceBase::Shutdown() {
try {
for (auto & all_link : all_links) {
for (auto &all_link : all_links) {
if (!all_link.sock.IsClosed()) {
all_link.sock.Close();
}
@ -133,12 +131,12 @@ bool AllreduceBase::Shutdown() {
if (tracker_uri == "NULL") return true;
// notify tracker rank i have shutdown
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("shutdown"));
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
tracker.Send(xgboost::StringView{"shutdown"});
tracker.Close();
utils::TCPSocket::Finalize();
xgboost::system::SocketFinalize();
return true;
} catch (const std::exception& e) {
} catch (std::exception const &e) {
LOG(WARNING) << "Failed to shutdown due to" << e.what();
return false;
}
@ -148,9 +146,9 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
if (tracker_uri == "NULL") {
utils::Printf("%s", msg.c_str()); return;
}
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("print"));
tracker.SendStr(msg);
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
tracker.Send(xgboost::StringView{"print"});
tracker.Send(xgboost::StringView{msg});
tracker.Close();
}
@ -227,21 +225,23 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
utils::TCPSocket AllreduceBase::ConnectTracker() const {
xgboost::collective::TCPSocket AllreduceBase::ConnectTracker() const {
int magic = kMagic;
// get information from tracker
utils::TCPSocket tracker;
tracker.Create();
xgboost::collective::TCPSocket tracker;
int retry = 0;
do {
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
auto rc = xgboost::collective::Connect(
xgboost::collective::MakeSockAddress(xgboost::StringView{tracker_uri}, tracker_port),
&tracker);
if (rc != std::errc()) {
if (++retry >= connect_retry) {
LOG(WARNING) << "Connect to (failed): [" << tracker_uri << "]\n";
utils::Socket::Error("Connect");
LOG(FATAL) << "Connecting to (failed): [" << tracker_uri << "]\n" << rc.message();
} else {
LOG(WARNING) << "Retry connect to ip(retry time " << retry << "): [" << tracker_uri << "]\n";
#if defined(_MSC_VER) || defined (__MINGW32__)
LOG(WARNING) << rc.message() << "\nRetry connecting to IP(retry time: " << retry << "): ["
<< tracker_uri << "]";
#if defined(_MSC_VER) || defined(__MINGW32__)
Sleep(retry << 1);
#else
sleep(retry << 1);
@ -253,16 +253,13 @@ utils::TCPSocket AllreduceBase::ConnectTracker() const {
} while (true);
using utils::Assert;
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
"ReConnectLink failure 1");
Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic),
"ReConnectLink failure 2");
CHECK_EQ(tracker.SendAll(&magic, sizeof(magic)), sizeof(magic));
CHECK_EQ(tracker.RecvAll(&magic, sizeof(magic)), sizeof(magic));
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 3");
Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 3");
tracker.SendStr(task_id);
CHECK_EQ(tracker.Send(xgboost::StringView{task_id}), task_id.size());
return tracker;
}
/*!
@ -272,12 +269,15 @@ utils::TCPSocket AllreduceBase::ConnectTracker() const {
bool AllreduceBase::ReConnectLinks(const char *cmd) {
// single node mode
if (tracker_uri == "NULL") {
rank = 0; world_size = 1; return true;
rank = 0;
world_size = 1;
return true;
}
try {
utils::TCPSocket tracker = this->ConnectTracker();
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
LOG(INFO) << "task " << task_id << " connected to the tracker";
tracker.SendStr(std::string(cmd));
tracker.Send(xgboost::StringView{cmd});
// the rank of previous link, next link in ring
int prev_rank, next_rank;
@ -315,13 +315,9 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
"ReConnectLink failure 4");
utils::TCPSocket sock_listen;
if (!sock_listen.IsClosed()) {
sock_listen.Close();
}
auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())};
// create listening socket
sock_listen.Create();
int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial);
int port = sock_listen.BindHost();
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
sock_listen.Listen();
@ -338,29 +334,27 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
}
}
int ngood = static_cast<int>(good_link.size());
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
"ReConnectLink failure 5");
for (int & i : good_link) {
Assert(tracker.SendAll(&i, sizeof(i)) == \
sizeof(i), "ReConnectLink failure 6");
// tracker construct goodset
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood), "ReConnectLink failure 5");
for (int &i : good_link) {
Assert(tracker.SendAll(&i, sizeof(i)) == sizeof(i), "ReConnectLink failure 6");
}
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
"ReConnectLink failure 7");
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \
sizeof(num_accept), "ReConnectLink failure 8");
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept),
"ReConnectLink failure 8");
num_error = 0;
for (int i = 0; i < num_conn; ++i) {
LinkRecord r;
int hport, hrank;
std::string hname;
tracker.RecvStr(&hname);
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
"ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
"ReConnectLink failure 10");
tracker.Recv(&hname);
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
r.sock.Create();
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
if (xgboost::collective::Connect(
xgboost::collective::MakeSockAddress(xgboost::StringView{hname}, hport), &r.sock) !=
std::errc{}) {
num_error += 1;
r.sock.Close();
continue;
@ -376,12 +370,12 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
if (all_link.rank == hrank) {
Assert(all_link.sock.IsClosed(),
"Override a link that is active");
all_link.sock = r.sock;
all_link.sock = std::move(r.sock);
match = true;
break;
}
}
if (!match) all_links.push_back(r);
if (!match) all_links.emplace_back(std::move(r));
}
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
"ReConnectLink failure 14");
@ -404,30 +398,24 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
if (all_link.rank == r.rank) {
utils::Assert(all_link.sock.IsClosed(),
"Override a link that is active");
all_link.sock = r.sock;
all_link.sock = std::move(r.sock);
match = true;
break;
}
}
if (!match) all_links.push_back(r);
if (!match) all_links.emplace_back(std::move(r));
}
sock_listen.Close();
this->parent_index = -1;
// setup tree links and ring structure
tree_links.plinks.clear();
int tcpNoDelay = 1;
for (auto & all_link : all_links) {
for (auto &all_link : all_links) {
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode, enable TCP keepalive
all_link.sock.SetNonBlock(true);
all_link.sock.SetKeepAlive(true);
all_link.sock.SetNonBlock();
all_link.sock.SetKeepAlive();
if (rabit_enable_tcp_no_delay) {
#if defined(__unix__)
setsockopt(all_link.sock, IPPROTO_TCP,
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
#else
LOG(WARNING) << "tcp no delay is not implemented on non unix platforms";
#endif
all_link.sock.SetNoDelay();
}
if (tree_neighbors.count(all_link.rank) != 0) {
if (all_link.rank == parent_rank) {

View File

@ -201,8 +201,8 @@ class AllreduceBase : public IEngine {
}
};
/*! \brief translate errno to return type */
inline static ReturnType Errno2Return() {
int errsv = utils::Socket::GetLastError();
static ReturnType Errno2Return() {
int errsv = xgboost::system::LastError();
if (errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == 0) return kSuccess;
#ifdef _WIN32
if (errsv == WSAEWOULDBLOCK) return kSuccess;
@ -215,7 +215,7 @@ class AllreduceBase : public IEngine {
struct LinkRecord {
public:
// socket to get data from/to link
utils::TCPSocket sock;
xgboost::collective::TCPSocket sock;
// rank of the node in this link
int rank;
// size of data readed from link
@ -329,7 +329,7 @@ class AllreduceBase : public IEngine {
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
utils::TCPSocket ConnectTracker() const;
xgboost::collective::TCPSocket ConnectTracker() const;
/*!
* \brief connect to the tracker to fix the the missing links
* this function is also used when the engine start up
@ -473,8 +473,6 @@ class AllreduceBase : public IEngine {
std::string dmlc_role; // NOLINT
// port of tracker address
int tracker_port; // NOLINT
// port of slave process
int slave_port, nport_trial; // NOLINT
// reduce buffer size
size_t reduce_buffer_size; // NOLINT
// reduction method

94
src/collective/socket.cc Normal file
View File

@ -0,0 +1,94 @@
/*!
* Copyright (c) 2022 by XGBoost Contributors
*/
#include "xgboost/collective/socket.h"
#include <cstddef> // std::size_t
#include <cstdint> // std::int32_t
#include <cstring> // std::memcpy, std::memset
#include <system_error> // std::error_code, std::system_category
#if defined(__unix__) || defined(__APPLE__)
#include <netdb.h> // getaddrinfo, freeaddrinfo
#endif // defined(__unix__) || defined(__APPLE__)
namespace xgboost {
namespace collective {
SockAddress MakeSockAddress(StringView host, in_port_t port) {
struct addrinfo hints;
std::memset(&hints, 0, sizeof(hints));
hints.ai_protocol = SOCK_STREAM;
struct addrinfo *res = nullptr;
int sig = getaddrinfo(host.c_str(), nullptr, &hints, &res);
if (sig != 0) {
return {};
}
if (res->ai_family == static_cast<std::int32_t>(SockDomain::kV4)) {
sockaddr_in addr;
std::memcpy(&addr, res->ai_addr, res->ai_addrlen);
addr.sin_port = htons(port);
auto v = SockAddrV4{addr};
freeaddrinfo(res);
return SockAddress{v};
} else if (res->ai_family == static_cast<std::int32_t>(SockDomain::kV6)) {
sockaddr_in6 addr;
std::memcpy(&addr, res->ai_addr, res->ai_addrlen);
addr.sin6_port = htons(port);
auto v = SockAddrV6{addr};
freeaddrinfo(res);
return SockAddress{v};
} else {
LOG(FATAL) << "Failed to get addr info for: " << host;
}
return SockAddress{};
}
SockAddrV4 SockAddrV4::Loopback() { return MakeSockAddress("127.0.0.1", 0).V4(); }
SockAddrV4 SockAddrV4::InaddrAny() { return MakeSockAddress("0.0.0.0", 0).V4(); }
SockAddrV6 SockAddrV6::Loopback() { return MakeSockAddress("::1", 0).V6(); }
SockAddrV6 SockAddrV6::InaddrAny() { return MakeSockAddress("::", 0).V6(); }
std::size_t TCPSocket::Send(StringView str) {
CHECK(!this->IsClosed());
CHECK_LT(str.size(), std::numeric_limits<std::int32_t>::max());
std::int32_t len = static_cast<std::int32_t>(str.size());
CHECK_EQ(this->SendAll(&len, sizeof(len)), sizeof(len)) << "Failed to send string length.";
auto bytes = this->SendAll(str.c_str(), str.size());
CHECK_EQ(bytes, str.size()) << "Failed to send string.";
return bytes;
}
std::size_t TCPSocket::Recv(std::string *p_str) {
CHECK(!this->IsClosed());
std::int32_t len;
CHECK_EQ(this->RecvAll(&len, sizeof(len)), sizeof(len)) << "Failed to recv string length.";
p_str->resize(len);
auto bytes = this->RecvAll(&(*p_str)[0], len);
CHECK_EQ(bytes, len) << "Failed to recv string.";
return bytes;
}
std::error_code Connect(SockAddress const &addr, TCPSocket *out) {
sockaddr const *addr_handle{nullptr};
socklen_t addr_len{0};
if (addr.IsV4()) {
addr_handle = reinterpret_cast<const sockaddr *>(&addr.V4().Handle());
addr_len = sizeof(addr.V4().Handle());
} else {
addr_handle = reinterpret_cast<const sockaddr *>(&addr.V6().Handle());
addr_len = sizeof(addr.V6().Handle());
}
auto socket = TCPSocket::Create(addr.Domain());
CHECK_EQ(static_cast<std::int32_t>(socket.Domain()), static_cast<std::int32_t>(addr.Domain()));
auto rc = connect(socket.Handle(), addr_handle, addr_len);
if (rc != 0) {
return std::error_code{errno, std::system_category()};
}
*out = std::move(socket);
return std::make_error_code(std::errc{});
}
} // namespace collective
} // namespace xgboost

View File

@ -121,6 +121,7 @@ if __name__ == "__main__":
"python-package/xgboost/sklearn.py",
"python-package/xgboost/spark",
"python-package/xgboost/federated.py",
"python-package/xgboost/testing.py",
# tests
"tests/python/test_config.py",
"tests/python/test_spark/",

View File

@ -0,0 +1,77 @@
/*!
* Copyright (c) 2022 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/collective/socket.h>
#include <cerrno> // EADDRNOTAVAIL
#include <fstream> // ifstream
#include <system_error> // std::error_code, std::system_category
#include "../helpers.h"
namespace xgboost {
namespace collective {
TEST(Socket, Basic) {
system::SocketStartup();
SockAddress addr{SockAddrV6::Loopback()};
ASSERT_TRUE(addr.IsV6());
addr = SockAddress{SockAddrV4::Loopback()};
ASSERT_TRUE(addr.IsV4());
std::string msg{"Skipping IPv6 test"};
auto run_test = [msg](SockDomain domain) {
auto server = TCPSocket::Create(domain);
ASSERT_EQ(server.Domain(), domain);
auto port = server.BindHost();
server.Listen();
TCPSocket client;
if (domain == SockDomain::kV4) {
auto const& addr = SockAddrV4::Loopback().Addr();
ASSERT_EQ(Connect(MakeSockAddress(StringView{addr}, port), &client), std::errc{});
} else {
auto const& addr = SockAddrV6::Loopback().Addr();
auto rc = Connect(MakeSockAddress(StringView{addr}, port), &client);
// some environment (docker) has restricted network configuration.
if (rc == std::error_code{EADDRNOTAVAIL, std::system_category()}) {
GTEST_SKIP_(msg.c_str());
}
ASSERT_EQ(rc, std::errc{});
}
ASSERT_EQ(client.Domain(), domain);
auto accepted = server.Accept();
StringView msg{"Hello world."};
accepted.Send(msg);
std::string str;
client.Recv(&str);
ASSERT_EQ(StringView{str}, msg);
};
run_test(SockDomain::kV4);
std::string path{"/sys/module/ipv6/parameters/disable"};
if (FileExists(path)) {
std::ifstream fin(path);
if (!fin) {
GTEST_SKIP_(msg.c_str());
}
std::string s_value;
fin >> s_value;
auto value = std::stoi(s_value);
if (value != 0) {
GTEST_SKIP_(msg.c_str());
}
} else {
GTEST_SKIP_(msg.c_str());
}
run_test(SockDomain::kV6);
system::SocketFinalize();
}
} // namespace collective
} // namespace xgboost

View File

@ -1,7 +1,6 @@
/*!
* Copyright (c) 2022 by XGBoost Contributors
*/
#ifndef XGBOOST_TESTS_CPP_FILESYSTEM_H
#define XGBOOST_TESTS_CPP_FILESYSTEM_H

View File

@ -1,8 +1,12 @@
/*!
* Copyright 2021-2022, XGBoost contributors.
*/
#ifndef XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
#define XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
#include <xgboost/tree_model.h>
#include <vector>
#include "../../../src/tree/hist/expand_entry.h"
namespace xgboost {
@ -19,3 +23,4 @@ inline void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntr
}
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_

View File

@ -1,34 +1,37 @@
from xgboost import RabitTracker
import xgboost as xgb
import re
import sys
import numpy as np
import pytest
import testing as tm
import numpy as np
import sys
import re
import xgboost as xgb
from xgboost import RabitTracker, testing
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
def test_rabit_tracker():
tracker = RabitTracker(host_ip='127.0.0.1', n_workers=1)
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
tracker.start(1)
worker_env = tracker.worker_envs()
rabit_env = []
for k, v in worker_env.items():
rabit_env.append(f"{k}={v}".encode())
with xgb.rabit.RabitContext(rabit_env):
ret = xgb.rabit.broadcast('test1234', 0)
assert str(ret) == 'test1234'
ret = xgb.rabit.broadcast("test1234", 0)
assert str(ret) == "test1234"
def run_rabit_ops(client, n_workers):
from test_with_dask import _get_client_workers
from xgboost.dask import RabitContext, _get_rabit_args
from xgboost.dask import RabitContext, _get_dask_config, _get_rabit_args
from xgboost import rabit
workers = _get_client_workers(client)
rabit_args = client.sync(_get_rabit_args, len(workers), None, client)
rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
assert not rabit.is_distributed()
n_workers_from_dask = len(workers)
assert n_workers == n_workers_from_dask
@ -55,12 +58,26 @@ def run_rabit_ops(client, n_workers):
@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops():
from distributed import Client, LocalCluster
n_workers = 3
with LocalCluster(n_workers=n_workers) as cluster:
with Client(cluster) as client:
run_rabit_ops(client, n_workers)
@pytest.mark.skipif(**testing.skip_ipv6())
@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops_ipv6():
import dask
from distributed import Client, LocalCluster
n_workers = 3
with dask.config.set({"xgboost.scheduler_address": "[::1]"}):
with LocalCluster(n_workers=n_workers, host="[::1]") as cluster:
with Client(cluster) as client:
run_rabit_ops(client, n_workers)
def test_rank_assignment() -> None:
from distributed import Client, LocalCluster
from test_with_dask import _get_client_workers