Initial support for IPv6 (#8225)
- Merge rabit socket into XGBoost. - Dask interface support. - Add test to the socket.
This commit is contained in:
@@ -1,65 +1,48 @@
|
||||
/*!
|
||||
* Copyright (c) 2014-2019 by Contributors
|
||||
* Copyright (c) 2014-2022 by XGBoost Contributors
|
||||
* \file socket.h
|
||||
* \brief this file aims to provide a wrapper of sockets
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_INTERNAL_SOCKET_H_
|
||||
#define RABIT_INTERNAL_SOCKET_H_
|
||||
#include "xgboost/collective/socket.h"
|
||||
|
||||
#if defined(_WIN32)
|
||||
#include <winsock2.h>
|
||||
#include <ws2tcpip.h>
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#pragma comment(lib, "Ws2_32.lib")
|
||||
#endif // _MSC_VER
|
||||
|
||||
#else
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <cerrno>
|
||||
#include <unistd.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <netinet/in.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/ioctl.h>
|
||||
#include <sys/socket.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#if defined(__sun) || defined(sun)
|
||||
#include <sys/sockio.h>
|
||||
#endif // defined(__sun) || defined(sun)
|
||||
#include <cerrno>
|
||||
|
||||
#endif // defined(_WIN32)
|
||||
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include <chrono>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#if defined(_WIN32) && !defined(__MINGW32__)
|
||||
typedef int ssize_t;
|
||||
#endif // defined(_WIN32) || defined(__MINGW32__)
|
||||
|
||||
#if defined(_WIN32)
|
||||
using sock_size_t = int;
|
||||
|
||||
#else
|
||||
#if !defined(_WIN32)
|
||||
|
||||
#include <sys/poll.h>
|
||||
|
||||
using SOCKET = int;
|
||||
using sock_size_t = size_t; // NOLINT
|
||||
#endif // defined(_WIN32)
|
||||
#endif // !defined(_WIN32)
|
||||
|
||||
#define IS_MINGW() defined(__MINGW32__)
|
||||
|
||||
#if IS_MINGW()
|
||||
inline void MingWError() {
|
||||
throw dmlc::Error("Distributed training on mingw is not supported.");
|
||||
}
|
||||
#endif // IS_MINGW()
|
||||
|
||||
#if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
|
||||
/*
|
||||
* On later mingw versions poll should be supported (with bugs). See:
|
||||
@@ -88,23 +71,17 @@ typedef struct pollfd {
|
||||
// POLLWRNORM
|
||||
#define POLLOUT 0x0010
|
||||
|
||||
inline const char *inet_ntop(int, const void *, char *, size_t) {
|
||||
MingWError();
|
||||
return nullptr;
|
||||
}
|
||||
#endif // IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
|
||||
|
||||
namespace rabit {
|
||||
namespace utils {
|
||||
|
||||
static constexpr int kInvalidSocket = -1;
|
||||
|
||||
template <typename PollFD>
|
||||
int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) {
|
||||
#if defined(_WIN32)
|
||||
|
||||
#if IS_MINGW()
|
||||
MingWError();
|
||||
xgboost::MingWError();
|
||||
return -1;
|
||||
#else
|
||||
return WSAPoll(pfd, nfds, std::chrono::milliseconds(timeout).count());
|
||||
@@ -115,458 +92,6 @@ int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) {
|
||||
#endif // IS_MINGW()
|
||||
}
|
||||
|
||||
/*! \brief data structure for network address */
|
||||
struct SockAddr {
|
||||
sockaddr_in addr;
|
||||
// constructor
|
||||
SockAddr() = default;
|
||||
SockAddr(const char *url, int port) {
|
||||
this->Set(url, port);
|
||||
}
|
||||
inline static std::string GetHostName() {
|
||||
std::string buf; buf.resize(256);
|
||||
#if !IS_MINGW()
|
||||
utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
|
||||
#endif // IS_MINGW()
|
||||
return std::string(buf.c_str());
|
||||
}
|
||||
/*!
|
||||
* \brief set the address
|
||||
* \param url the url of the address
|
||||
* \param port the port of address
|
||||
*/
|
||||
inline void Set(const char *host, int port) {
|
||||
#if !IS_MINGW()
|
||||
addrinfo hints;
|
||||
memset(&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_INET;
|
||||
hints.ai_protocol = SOCK_STREAM;
|
||||
addrinfo *res = nullptr;
|
||||
int sig = getaddrinfo(host, nullptr, &hints, &res);
|
||||
Check(sig == 0 && res != nullptr, "cannot obtain address of %s", host);
|
||||
Check(res->ai_family == AF_INET, "Does not support IPv6");
|
||||
memcpy(&addr, res->ai_addr, res->ai_addrlen);
|
||||
addr.sin_port = htons(port);
|
||||
freeaddrinfo(res);
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*! \brief return port of the address*/
|
||||
inline int Port() const {
|
||||
return ntohs(addr.sin_port);
|
||||
}
|
||||
/*! \return a string representation of the address */
|
||||
inline std::string AddrStr() const {
|
||||
std::string buf; buf.resize(256);
|
||||
#ifdef _WIN32
|
||||
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
|
||||
&buf[0], buf.length());
|
||||
#else
|
||||
const char *s = inet_ntop(AF_INET, &addr.sin_addr,
|
||||
&buf[0], buf.length());
|
||||
#endif // _WIN32
|
||||
Assert(s != nullptr, "cannot decode address");
|
||||
return std::string(s);
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief base class containing common operations of TCP and UDP sockets
|
||||
*/
|
||||
class Socket {
|
||||
public:
|
||||
/*! \brief the file descriptor of socket */
|
||||
SOCKET sockfd;
|
||||
// default conversion to int
|
||||
operator SOCKET() const { // NOLINT
|
||||
return sockfd;
|
||||
}
|
||||
/*!
|
||||
* \return last error of socket operation
|
||||
*/
|
||||
inline static int GetLastError() {
|
||||
#ifdef _WIN32
|
||||
|
||||
#if IS_MINGW()
|
||||
MingWError();
|
||||
return -1;
|
||||
#else
|
||||
return WSAGetLastError();
|
||||
#endif // IS_MINGW()
|
||||
|
||||
#else
|
||||
return errno;
|
||||
#endif // _WIN32
|
||||
}
|
||||
/*! \return whether last error was would block */
|
||||
inline static bool LastErrorWouldBlock() {
|
||||
int errsv = GetLastError();
|
||||
#ifdef _WIN32
|
||||
return errsv == WSAEWOULDBLOCK;
|
||||
#else
|
||||
return errsv == EAGAIN || errsv == EWOULDBLOCK;
|
||||
#endif // _WIN32
|
||||
}
|
||||
/*!
|
||||
* \brief start up the socket module
|
||||
* call this before using the sockets
|
||||
*/
|
||||
inline static void Startup() {
|
||||
#ifdef _WIN32
|
||||
#if !IS_MINGW()
|
||||
WSADATA wsa_data;
|
||||
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
|
||||
Socket::Error("Startup");
|
||||
}
|
||||
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
|
||||
WSACleanup();
|
||||
utils::Error("Could not find a usable version of Winsock.dll\n");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
#endif // _WIN32
|
||||
}
|
||||
/*!
|
||||
* \brief shutdown the socket module after use, all sockets need to be closed
|
||||
*/
|
||||
inline static void Finalize() {
|
||||
#ifdef _WIN32
|
||||
#if !IS_MINGW()
|
||||
WSACleanup();
|
||||
#endif // !IS_MINGW()
|
||||
#endif // _WIN32
|
||||
}
|
||||
/*!
|
||||
* \brief set this socket to use non-blocking mode
|
||||
* \param non_block whether set it to be non-block, if it is false
|
||||
* it will set it back to block mode
|
||||
*/
|
||||
inline void SetNonBlock(bool non_block) {
|
||||
#ifdef _WIN32
|
||||
#if !IS_MINGW()
|
||||
u_long mode = non_block ? 1 : 0;
|
||||
if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
|
||||
Socket::Error("SetNonBlock");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
#else
|
||||
int flag = fcntl(sockfd, F_GETFL, 0);
|
||||
if (flag == -1) {
|
||||
Socket::Error("SetNonBlock-1");
|
||||
}
|
||||
if (non_block) {
|
||||
flag |= O_NONBLOCK;
|
||||
} else {
|
||||
flag &= ~O_NONBLOCK;
|
||||
}
|
||||
if (fcntl(sockfd, F_SETFL, flag) == -1) {
|
||||
Socket::Error("SetNonBlock-2");
|
||||
}
|
||||
#endif // _WIN32
|
||||
}
|
||||
/*!
|
||||
* \brief bind the socket to an address
|
||||
* \param addr
|
||||
*/
|
||||
inline void Bind(const SockAddr &addr) {
|
||||
#if !IS_MINGW()
|
||||
if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
|
||||
sizeof(addr.addr)) == -1) {
|
||||
Socket::Error("Bind");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief try bind the socket to host, from start_port to end_port
|
||||
* \param start_port starting port number to try
|
||||
* \param end_port ending port number to try
|
||||
* \return the port successfully bind to, return -1 if failed to bind any port
|
||||
*/
|
||||
inline int TryBindHost(int start_port, int end_port) {
|
||||
// TODO(tqchen) add prefix check
|
||||
#if !IS_MINGW()
|
||||
for (int port = start_port; port < end_port; ++port) {
|
||||
SockAddr addr("0.0.0.0", port);
|
||||
if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
|
||||
sizeof(addr.addr)) == 0) {
|
||||
return port;
|
||||
}
|
||||
#if defined(_WIN32)
|
||||
if (WSAGetLastError() != WSAEADDRINUSE) {
|
||||
Socket::Error("TryBindHost");
|
||||
}
|
||||
#else
|
||||
if (errno != EADDRINUSE) {
|
||||
Socket::Error("TryBindHost");
|
||||
}
|
||||
#endif // defined(_WIN32)
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
return -1;
|
||||
}
|
||||
/*! \brief get last error code if any */
|
||||
inline int GetSockError() const {
|
||||
int error = 0;
|
||||
socklen_t len = sizeof(error);
|
||||
#if !IS_MINGW()
|
||||
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
|
||||
reinterpret_cast<char *>(&error), &len) != 0) {
|
||||
Error("GetSockError");
|
||||
}
|
||||
#else
|
||||
// undefined reference to `_imp__getsockopt@20'
|
||||
MingWError();
|
||||
#endif // !IS_MINGW()
|
||||
return error;
|
||||
}
|
||||
/*! \brief check if anything bad happens */
|
||||
inline bool BadSocket() const {
|
||||
if (IsClosed()) return true;
|
||||
int err = GetSockError();
|
||||
if (err == EBADF || err == EINTR) return true;
|
||||
return false;
|
||||
}
|
||||
/*! \brief check if socket is already closed */
|
||||
inline bool IsClosed() const {
|
||||
return sockfd == kInvalidSocket;
|
||||
}
|
||||
/*! \brief close the socket */
|
||||
inline void Close() {
|
||||
if (sockfd != kInvalidSocket) {
|
||||
#ifdef _WIN32
|
||||
#if !IS_MINGW()
|
||||
closesocket(sockfd);
|
||||
#endif // !IS_MINGW()
|
||||
#else
|
||||
close(sockfd);
|
||||
#endif
|
||||
sockfd = kInvalidSocket;
|
||||
} else {
|
||||
Error("Socket::Close double close the socket or close without create");
|
||||
}
|
||||
}
|
||||
// report an socket error
|
||||
inline static void Error(const char *msg) {
|
||||
int errsv = GetLastError();
|
||||
#ifdef _WIN32
|
||||
utils::Error("Socket %s Error:WSAError-code=%d", msg, errsv);
|
||||
#else
|
||||
utils::Error("Socket %s Error:%s", msg, strerror(errsv));
|
||||
#endif
|
||||
}
|
||||
|
||||
protected:
|
||||
explicit Socket(SOCKET sockfd) : sockfd(sockfd) {
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief a wrapper of TCP socket that hopefully be cross platform
|
||||
*/
|
||||
class TCPSocket : public Socket{
|
||||
public:
|
||||
// constructor
|
||||
TCPSocket() : Socket(kInvalidSocket) {
|
||||
}
|
||||
explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) {
|
||||
}
|
||||
/*!
|
||||
* \brief enable/disable TCP keepalive
|
||||
* \param keepalive whether to set the keep alive option on
|
||||
*/
|
||||
void SetKeepAlive(bool keepalive) {
|
||||
#if !IS_MINGW()
|
||||
int opt = static_cast<int>(keepalive);
|
||||
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
|
||||
reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
|
||||
Socket::Error("SetKeepAlive");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
inline void SetLinger(int timeout = 0) {
|
||||
#if !IS_MINGW()
|
||||
struct linger sl;
|
||||
sl.l_onoff = 1; /* non-zero value enables linger option in kernel */
|
||||
sl.l_linger = timeout; /* timeout interval in seconds */
|
||||
if (setsockopt(sockfd, SOL_SOCKET, SO_LINGER, reinterpret_cast<char*>(&sl), sizeof(sl)) == -1) {
|
||||
Socket::Error("SO_LINGER");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief create the socket, call this before using socket
|
||||
* \param af domain
|
||||
*/
|
||||
inline void Create(int af = PF_INET) {
|
||||
#if !IS_MINGW()
|
||||
sockfd = socket(af, SOCK_STREAM, 0);
|
||||
if (sockfd == kInvalidSocket) {
|
||||
Socket::Error("Create");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief perform listen of the socket
|
||||
* \param backlog backlog parameter
|
||||
*/
|
||||
inline void Listen(int backlog = 16) {
|
||||
#if !IS_MINGW()
|
||||
listen(sockfd, backlog);
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*! \brief get a new connection */
|
||||
TCPSocket Accept() {
|
||||
#if !IS_MINGW()
|
||||
SOCKET newfd = accept(sockfd, nullptr, nullptr);
|
||||
if (newfd == kInvalidSocket) {
|
||||
Socket::Error("Accept");
|
||||
}
|
||||
return TCPSocket(newfd);
|
||||
#else
|
||||
return TCPSocket();
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief decide whether the socket is at OOB mark
|
||||
* \return 1 if at mark, 0 if not, -1 if an error occured
|
||||
*/
|
||||
inline int AtMark() const {
|
||||
#if !IS_MINGW()
|
||||
|
||||
#ifdef _WIN32
|
||||
unsigned long atmark; // NOLINT(*)
|
||||
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
|
||||
#else
|
||||
int atmark;
|
||||
if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
|
||||
#endif // _WIN32
|
||||
|
||||
return static_cast<int>(atmark);
|
||||
|
||||
#else
|
||||
return -1;
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief connect to an address
|
||||
* \param addr the address to connect to
|
||||
* \return whether connect is successful
|
||||
*/
|
||||
inline bool Connect(const SockAddr &addr) {
|
||||
#if !IS_MINGW()
|
||||
return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
|
||||
sizeof(addr.addr)) == 0;
|
||||
#else
|
||||
return false;
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief send data using the socket
|
||||
* \param buf the pointer to the buffer
|
||||
* \param len the size of the buffer
|
||||
* \param flags extra flags
|
||||
* \return size of data actually sent
|
||||
* return -1 if error occurs
|
||||
*/
|
||||
inline ssize_t Send(const void *buf_, size_t len, int flag = 0) {
|
||||
const char *buf = reinterpret_cast<const char*>(buf_);
|
||||
#if !IS_MINGW()
|
||||
return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
|
||||
#else
|
||||
return 0;
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief receive data using the socket
|
||||
* \param buf_ the pointer to the buffer
|
||||
* \param len the size of the buffer
|
||||
* \param flags extra flags
|
||||
* \return size of data actually received
|
||||
* return -1 if error occurs
|
||||
*/
|
||||
inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
|
||||
char *buf = reinterpret_cast<char*>(buf_);
|
||||
#if !IS_MINGW()
|
||||
return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
|
||||
#else
|
||||
return 0;
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief peform block write that will attempt to send all data out
|
||||
* can still return smaller than request when error occurs
|
||||
* \param buf the pointer to the buffer
|
||||
* \param len the size of the buffer
|
||||
* \return size of data actually sent
|
||||
*/
|
||||
inline size_t SendAll(const void *buf_, size_t len) {
|
||||
const char *buf = reinterpret_cast<const char*>(buf_);
|
||||
size_t ndone = 0;
|
||||
#if !IS_MINGW()
|
||||
while (ndone < len) {
|
||||
ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
|
||||
if (ret == -1) {
|
||||
if (LastErrorWouldBlock()) return ndone;
|
||||
Socket::Error("SendAll");
|
||||
}
|
||||
buf += ret;
|
||||
ndone += ret;
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
return ndone;
|
||||
}
|
||||
/*!
|
||||
* \brief peforma block read that will attempt to read all data
|
||||
* can still return smaller than request when error occurs
|
||||
* \param buf_ the buffer pointer
|
||||
* \param len length of data to recv
|
||||
* \return size of data actually sent
|
||||
*/
|
||||
inline size_t RecvAll(void *buf_, size_t len) {
|
||||
char *buf = reinterpret_cast<char*>(buf_);
|
||||
size_t ndone = 0;
|
||||
#if !IS_MINGW()
|
||||
while (ndone < len) {
|
||||
ssize_t ret = recv(sockfd, buf,
|
||||
static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
|
||||
if (ret == -1) {
|
||||
if (LastErrorWouldBlock()) return ndone;
|
||||
Socket::Error("RecvAll");
|
||||
}
|
||||
if (ret == 0) return ndone;
|
||||
buf += ret;
|
||||
ndone += ret;
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
return ndone;
|
||||
}
|
||||
/*!
|
||||
* \brief send a string over network
|
||||
* \param str the string to be sent
|
||||
*/
|
||||
inline void SendStr(const std::string &str) {
|
||||
int len = static_cast<int>(str.length());
|
||||
utils::Assert(this->SendAll(&len, sizeof(len)) == sizeof(len),
|
||||
"error during send SendStr");
|
||||
if (len != 0) {
|
||||
utils::Assert(this->SendAll(str.c_str(), str.length()) == str.length(),
|
||||
"error during send SendStr");
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief recv a string from network
|
||||
* \param out_str the string to receive
|
||||
*/
|
||||
inline void RecvStr(std::string *out_str) {
|
||||
int len;
|
||||
utils::Assert(this->RecvAll(&len, sizeof(len)) == sizeof(len),
|
||||
"error during send RecvStr");
|
||||
out_str->resize(len);
|
||||
if (len != 0) {
|
||||
utils::Assert(this->RecvAll(&(*out_str)[0], len) == out_str->length(),
|
||||
"error during send SendStr");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief helper data structure to perform poll */
|
||||
struct PollHelper {
|
||||
public:
|
||||
@@ -579,6 +104,8 @@ struct PollHelper {
|
||||
pfd.fd = fd;
|
||||
pfd.events |= POLLIN;
|
||||
}
|
||||
void WatchRead(xgboost::collective::TCPSocket const &socket) { this->WatchRead(socket.Handle()); }
|
||||
|
||||
/*!
|
||||
* \brief add file descriptor to watch for write
|
||||
* \param fd file descriptor to be watched
|
||||
@@ -588,6 +115,10 @@ struct PollHelper {
|
||||
pfd.fd = fd;
|
||||
pfd.events |= POLLOUT;
|
||||
}
|
||||
void WatchWrite(xgboost::collective::TCPSocket const &socket) {
|
||||
this->WatchWrite(socket.Handle());
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief add file descriptor to watch for exception
|
||||
* \param fd file descriptor to be watched
|
||||
@@ -597,6 +128,9 @@ struct PollHelper {
|
||||
pfd.fd = fd;
|
||||
pfd.events |= POLLPRI;
|
||||
}
|
||||
void WatchException(xgboost::collective::TCPSocket const &socket) {
|
||||
this->WatchException(socket.Handle());
|
||||
}
|
||||
/*!
|
||||
* \brief Check if the descriptor is ready for read
|
||||
* \param fd file descriptor to check status
|
||||
@@ -605,6 +139,10 @@ struct PollHelper {
|
||||
const auto& pfd = fds.find(fd);
|
||||
return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
|
||||
}
|
||||
bool CheckRead(xgboost::collective::TCPSocket const &socket) const {
|
||||
return this->CheckRead(socket.Handle());
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Check if the descriptor is ready for write
|
||||
* \param fd file descriptor to check status
|
||||
@@ -613,7 +151,9 @@ struct PollHelper {
|
||||
const auto& pfd = fds.find(fd);
|
||||
return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
|
||||
}
|
||||
|
||||
bool CheckWrite(xgboost::collective::TCPSocket const &socket) const {
|
||||
return this->CheckWrite(socket.Handle());
|
||||
}
|
||||
/*!
|
||||
* \brief perform poll on the set defined, read, write, exception
|
||||
* \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
|
||||
@@ -629,7 +169,7 @@ struct PollHelper {
|
||||
if (ret == 0) {
|
||||
LOG(FATAL) << "Poll timeout";
|
||||
} else if (ret < 0) {
|
||||
Socket::Error("Poll");
|
||||
LOG(FATAL) << "Failed to poll.";
|
||||
} else {
|
||||
for (auto& pfd : fdset) {
|
||||
auto revents = pfd.revents & pfd.events;
|
||||
|
||||
@@ -8,15 +8,17 @@
|
||||
#define RABIT_INTERNAL_UTILS_H_
|
||||
|
||||
#include <rabit/base.h>
|
||||
#include <cstring>
|
||||
|
||||
#include <cstdarg>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "dmlc/io.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include <cstdarg>
|
||||
|
||||
#if !defined(__GNUC__) || defined(__FreeBSD__)
|
||||
#define fopen64 std::fopen
|
||||
|
||||
@@ -27,8 +27,6 @@ AllreduceBase::AllreduceBase() {
|
||||
tracker_uri = "NULL";
|
||||
tracker_port = 9000;
|
||||
host_uri = "";
|
||||
slave_port = 9010;
|
||||
nport_trial = 1000;
|
||||
rank = 0;
|
||||
world_size = -1;
|
||||
connect_retry = 5;
|
||||
@@ -114,16 +112,16 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
|
||||
this->rank = -1;
|
||||
//---------------------
|
||||
// start socket
|
||||
utils::Socket::Startup();
|
||||
xgboost::system::SocketStartup();
|
||||
utils::Assert(all_links.size() == 0, "can only call Init once");
|
||||
this->host_uri = utils::SockAddr::GetHostName();
|
||||
this->host_uri = xgboost::collective::GetHostName();
|
||||
// get information from tracker
|
||||
return this->ReConnectLinks();
|
||||
}
|
||||
|
||||
bool AllreduceBase::Shutdown() {
|
||||
try {
|
||||
for (auto & all_link : all_links) {
|
||||
for (auto &all_link : all_links) {
|
||||
if (!all_link.sock.IsClosed()) {
|
||||
all_link.sock.Close();
|
||||
}
|
||||
@@ -133,12 +131,12 @@ bool AllreduceBase::Shutdown() {
|
||||
|
||||
if (tracker_uri == "NULL") return true;
|
||||
// notify tracker rank i have shutdown
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string("shutdown"));
|
||||
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.Send(xgboost::StringView{"shutdown"});
|
||||
tracker.Close();
|
||||
utils::TCPSocket::Finalize();
|
||||
xgboost::system::SocketFinalize();
|
||||
return true;
|
||||
} catch (const std::exception& e) {
|
||||
} catch (std::exception const &e) {
|
||||
LOG(WARNING) << "Failed to shutdown due to" << e.what();
|
||||
return false;
|
||||
}
|
||||
@@ -148,9 +146,9 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
|
||||
if (tracker_uri == "NULL") {
|
||||
utils::Printf("%s", msg.c_str()); return;
|
||||
}
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string("print"));
|
||||
tracker.SendStr(msg);
|
||||
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.Send(xgboost::StringView{"print"});
|
||||
tracker.Send(xgboost::StringView{msg});
|
||||
tracker.Close();
|
||||
}
|
||||
|
||||
@@ -227,21 +225,23 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
* \brief initialize connection to the tracker
|
||||
* \return a socket that initializes the connection
|
||||
*/
|
||||
utils::TCPSocket AllreduceBase::ConnectTracker() const {
|
||||
xgboost::collective::TCPSocket AllreduceBase::ConnectTracker() const {
|
||||
int magic = kMagic;
|
||||
// get information from tracker
|
||||
utils::TCPSocket tracker;
|
||||
tracker.Create();
|
||||
xgboost::collective::TCPSocket tracker;
|
||||
|
||||
int retry = 0;
|
||||
do {
|
||||
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
|
||||
auto rc = xgboost::collective::Connect(
|
||||
xgboost::collective::MakeSockAddress(xgboost::StringView{tracker_uri}, tracker_port),
|
||||
&tracker);
|
||||
if (rc != std::errc()) {
|
||||
if (++retry >= connect_retry) {
|
||||
LOG(WARNING) << "Connect to (failed): [" << tracker_uri << "]\n";
|
||||
utils::Socket::Error("Connect");
|
||||
LOG(FATAL) << "Connecting to (failed): [" << tracker_uri << "]\n" << rc.message();
|
||||
} else {
|
||||
LOG(WARNING) << "Retry connect to ip(retry time " << retry << "): [" << tracker_uri << "]\n";
|
||||
#if defined(_MSC_VER) || defined (__MINGW32__)
|
||||
LOG(WARNING) << rc.message() << "\nRetry connecting to IP(retry time: " << retry << "): ["
|
||||
<< tracker_uri << "]";
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
Sleep(retry << 1);
|
||||
#else
|
||||
sleep(retry << 1);
|
||||
@@ -253,16 +253,13 @@ utils::TCPSocket AllreduceBase::ConnectTracker() const {
|
||||
} while (true);
|
||||
|
||||
using utils::Assert;
|
||||
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
|
||||
"ReConnectLink failure 1");
|
||||
Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic),
|
||||
"ReConnectLink failure 2");
|
||||
CHECK_EQ(tracker.SendAll(&magic, sizeof(magic)), sizeof(magic));
|
||||
CHECK_EQ(tracker.RecvAll(&magic, sizeof(magic)), sizeof(magic));
|
||||
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
|
||||
Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||
"ReConnectLink failure 3");
|
||||
Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
||||
Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
||||
"ReConnectLink failure 3");
|
||||
tracker.SendStr(task_id);
|
||||
CHECK_EQ(tracker.Send(xgboost::StringView{task_id}), task_id.size());
|
||||
return tracker;
|
||||
}
|
||||
/*!
|
||||
@@ -272,12 +269,15 @@ utils::TCPSocket AllreduceBase::ConnectTracker() const {
|
||||
bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
// single node mode
|
||||
if (tracker_uri == "NULL") {
|
||||
rank = 0; world_size = 1; return true;
|
||||
rank = 0;
|
||||
world_size = 1;
|
||||
return true;
|
||||
}
|
||||
|
||||
try {
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
|
||||
LOG(INFO) << "task " << task_id << " connected to the tracker";
|
||||
tracker.SendStr(std::string(cmd));
|
||||
tracker.Send(xgboost::StringView{cmd});
|
||||
|
||||
// the rank of previous link, next link in ring
|
||||
int prev_rank, next_rank;
|
||||
@@ -315,13 +315,9 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
|
||||
"ReConnectLink failure 4");
|
||||
|
||||
utils::TCPSocket sock_listen;
|
||||
if (!sock_listen.IsClosed()) {
|
||||
sock_listen.Close();
|
||||
}
|
||||
auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())};
|
||||
// create listening socket
|
||||
sock_listen.Create();
|
||||
int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial);
|
||||
int port = sock_listen.BindHost();
|
||||
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
|
||||
sock_listen.Listen();
|
||||
|
||||
@@ -338,29 +334,27 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
}
|
||||
}
|
||||
int ngood = static_cast<int>(good_link.size());
|
||||
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
|
||||
"ReConnectLink failure 5");
|
||||
for (int & i : good_link) {
|
||||
Assert(tracker.SendAll(&i, sizeof(i)) == \
|
||||
sizeof(i), "ReConnectLink failure 6");
|
||||
// tracker construct goodset
|
||||
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood), "ReConnectLink failure 5");
|
||||
for (int &i : good_link) {
|
||||
Assert(tracker.SendAll(&i, sizeof(i)) == sizeof(i), "ReConnectLink failure 6");
|
||||
}
|
||||
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
||||
"ReConnectLink failure 7");
|
||||
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \
|
||||
sizeof(num_accept), "ReConnectLink failure 8");
|
||||
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept),
|
||||
"ReConnectLink failure 8");
|
||||
num_error = 0;
|
||||
for (int i = 0; i < num_conn; ++i) {
|
||||
LinkRecord r;
|
||||
int hport, hrank;
|
||||
std::string hname;
|
||||
tracker.RecvStr(&hname);
|
||||
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
|
||||
"ReConnectLink failure 9");
|
||||
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
|
||||
"ReConnectLink failure 10");
|
||||
tracker.Recv(&hname);
|
||||
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
|
||||
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
|
||||
|
||||
r.sock.Create();
|
||||
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
|
||||
if (xgboost::collective::Connect(
|
||||
xgboost::collective::MakeSockAddress(xgboost::StringView{hname}, hport), &r.sock) !=
|
||||
std::errc{}) {
|
||||
num_error += 1;
|
||||
r.sock.Close();
|
||||
continue;
|
||||
@@ -376,12 +370,12 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
if (all_link.rank == hrank) {
|
||||
Assert(all_link.sock.IsClosed(),
|
||||
"Override a link that is active");
|
||||
all_link.sock = r.sock;
|
||||
all_link.sock = std::move(r.sock);
|
||||
match = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!match) all_links.push_back(r);
|
||||
if (!match) all_links.emplace_back(std::move(r));
|
||||
}
|
||||
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
|
||||
"ReConnectLink failure 14");
|
||||
@@ -404,30 +398,24 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
if (all_link.rank == r.rank) {
|
||||
utils::Assert(all_link.sock.IsClosed(),
|
||||
"Override a link that is active");
|
||||
all_link.sock = r.sock;
|
||||
all_link.sock = std::move(r.sock);
|
||||
match = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!match) all_links.push_back(r);
|
||||
if (!match) all_links.emplace_back(std::move(r));
|
||||
}
|
||||
sock_listen.Close();
|
||||
this->parent_index = -1;
|
||||
// setup tree links and ring structure
|
||||
tree_links.plinks.clear();
|
||||
int tcpNoDelay = 1;
|
||||
for (auto & all_link : all_links) {
|
||||
for (auto &all_link : all_links) {
|
||||
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
|
||||
// set the socket to non-blocking mode, enable TCP keepalive
|
||||
all_link.sock.SetNonBlock(true);
|
||||
all_link.sock.SetKeepAlive(true);
|
||||
all_link.sock.SetNonBlock();
|
||||
all_link.sock.SetKeepAlive();
|
||||
if (rabit_enable_tcp_no_delay) {
|
||||
#if defined(__unix__)
|
||||
setsockopt(all_link.sock, IPPROTO_TCP,
|
||||
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
|
||||
#else
|
||||
LOG(WARNING) << "tcp no delay is not implemented on non unix platforms";
|
||||
#endif
|
||||
all_link.sock.SetNoDelay();
|
||||
}
|
||||
if (tree_neighbors.count(all_link.rank) != 0) {
|
||||
if (all_link.rank == parent_rank) {
|
||||
|
||||
@@ -201,8 +201,8 @@ class AllreduceBase : public IEngine {
|
||||
}
|
||||
};
|
||||
/*! \brief translate errno to return type */
|
||||
inline static ReturnType Errno2Return() {
|
||||
int errsv = utils::Socket::GetLastError();
|
||||
static ReturnType Errno2Return() {
|
||||
int errsv = xgboost::system::LastError();
|
||||
if (errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == 0) return kSuccess;
|
||||
#ifdef _WIN32
|
||||
if (errsv == WSAEWOULDBLOCK) return kSuccess;
|
||||
@@ -215,7 +215,7 @@ class AllreduceBase : public IEngine {
|
||||
struct LinkRecord {
|
||||
public:
|
||||
// socket to get data from/to link
|
||||
utils::TCPSocket sock;
|
||||
xgboost::collective::TCPSocket sock;
|
||||
// rank of the node in this link
|
||||
int rank;
|
||||
// size of data readed from link
|
||||
@@ -329,7 +329,7 @@ class AllreduceBase : public IEngine {
|
||||
* \brief initialize connection to the tracker
|
||||
* \return a socket that initializes the connection
|
||||
*/
|
||||
utils::TCPSocket ConnectTracker() const;
|
||||
xgboost::collective::TCPSocket ConnectTracker() const;
|
||||
/*!
|
||||
* \brief connect to the tracker to fix the the missing links
|
||||
* this function is also used when the engine start up
|
||||
@@ -473,8 +473,6 @@ class AllreduceBase : public IEngine {
|
||||
std::string dmlc_role; // NOLINT
|
||||
// port of tracker address
|
||||
int tracker_port; // NOLINT
|
||||
// port of slave process
|
||||
int slave_port, nport_trial; // NOLINT
|
||||
// reduce buffer size
|
||||
size_t reduce_buffer_size; // NOLINT
|
||||
// reduction method
|
||||
|
||||
Reference in New Issue
Block a user