Enable building rabit on Windows (#6105)
This commit is contained in:
@@ -9,10 +9,13 @@
|
||||
#if defined(_WIN32)
|
||||
#include <winsock2.h>
|
||||
#include <ws2tcpip.h>
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#pragma comment(lib, "Ws2_32.lib")
|
||||
#endif // _MSC_VER
|
||||
|
||||
#else
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <cerrno>
|
||||
@@ -21,31 +24,92 @@
|
||||
#include <netinet/in.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/ioctl.h>
|
||||
|
||||
#endif // defined(_WIN32)
|
||||
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include "utils.h"
|
||||
|
||||
#if defined(_WIN32) || defined(__MINGW32__)
|
||||
#if defined(_WIN32) && !defined(__MINGW32__)
|
||||
typedef int ssize_t;
|
||||
#endif // defined(_WIN32) || defined(__MINGW32__)
|
||||
|
||||
#if defined(_WIN32)
|
||||
typedef int sock_size_t;
|
||||
using sock_size_t = int;
|
||||
|
||||
static inline int poll(struct pollfd *pfd, int nfds,
|
||||
int timeout) { return WSAPoll ( pfd, nfds, timeout ); }
|
||||
#else
|
||||
|
||||
#include <sys/poll.h>
|
||||
using SOCKET = int;
|
||||
using sock_size_t = size_t; // NOLINT
|
||||
const int kInvalidSocket = -1;
|
||||
#endif // defined(_WIN32)
|
||||
|
||||
#define IS_MINGW() defined(__MINGW32__)
|
||||
|
||||
#if IS_MINGW()
|
||||
inline void MingWError() {
|
||||
throw dmlc::Error("Distributed training on mingw is not supported.");
|
||||
}
|
||||
#endif // IS_MINGW()
|
||||
|
||||
#if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
|
||||
/*
|
||||
* On later mingw versions poll should be supported (with bugs). See:
|
||||
* https://stackoverflow.com/a/60623080
|
||||
*
|
||||
* But right now the mingw distributed with R 3.6 doesn't support it.
|
||||
* So we just give a warning and provide dummy implementation to get
|
||||
* compilation passed. Otherwise we will have to provide a stub for
|
||||
* RABIT.
|
||||
*
|
||||
* Even on mingw version that has these structures and flags defined,
|
||||
* functions like `send` and `listen` might have unresolved linkage to
|
||||
* their implementation. So supporting mingw is quite difficult at
|
||||
* the time of writing.
|
||||
*/
|
||||
#pragma message("Distributed training on mingw is not supported.")
|
||||
typedef struct pollfd {
|
||||
SOCKET fd;
|
||||
short events;
|
||||
short revents;
|
||||
} WSAPOLLFD, *PWSAPOLLFD, *LPWSAPOLLFD;
|
||||
|
||||
// POLLRDNORM | POLLRDBAND
|
||||
#define POLLIN (0x0100 | 0x0200)
|
||||
#define POLLPRI 0x0400
|
||||
// POLLWRNORM
|
||||
#define POLLOUT 0x0010
|
||||
|
||||
inline const char *inet_ntop(int, const void *, char *, size_t) {
|
||||
MingWError();
|
||||
return nullptr;
|
||||
}
|
||||
#endif // IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
|
||||
|
||||
namespace rabit {
|
||||
namespace utils {
|
||||
|
||||
static constexpr int kInvalidSocket = -1;
|
||||
|
||||
template <typename PollFD>
|
||||
int PollImpl(PollFD *pfd, int nfds, int timeout) {
|
||||
#if defined(_WIN32)
|
||||
|
||||
#if IS_MINGW()
|
||||
MingWError();
|
||||
return -1;
|
||||
#else
|
||||
return WSAPoll(pfd, nfds, timeout);
|
||||
#endif // IS_MINGW()
|
||||
|
||||
#else
|
||||
return poll(pfd, nfds, timeout);
|
||||
#endif // IS_MINGW()
|
||||
}
|
||||
|
||||
/*! \brief data structure for network address */
|
||||
struct SockAddr {
|
||||
sockaddr_in addr;
|
||||
@@ -56,7 +120,9 @@ struct SockAddr {
|
||||
}
|
||||
inline static std::string GetHostName() {
|
||||
std::string buf; buf.resize(256);
|
||||
#if !IS_MINGW()
|
||||
utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
|
||||
#endif // IS_MINGW()
|
||||
return std::string(buf.c_str());
|
||||
}
|
||||
/*!
|
||||
@@ -65,6 +131,7 @@ struct SockAddr {
|
||||
* \param port the port of address
|
||||
*/
|
||||
inline void Set(const char *host, int port) {
|
||||
#if !IS_MINGW()
|
||||
addrinfo hints;
|
||||
memset(&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_INET;
|
||||
@@ -76,6 +143,7 @@ struct SockAddr {
|
||||
memcpy(&addr, res->ai_addr, res->ai_addrlen);
|
||||
addr.sin_port = htons(port);
|
||||
freeaddrinfo(res);
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*! \brief return port of the address*/
|
||||
inline int Port() const {
|
||||
@@ -112,7 +180,14 @@ class Socket {
|
||||
*/
|
||||
inline static int GetLastError() {
|
||||
#ifdef _WIN32
|
||||
|
||||
#if IS_MINGW()
|
||||
MingWError();
|
||||
return -1;
|
||||
#else
|
||||
return WSAGetLastError();
|
||||
#endif // IS_MINGW()
|
||||
|
||||
#else
|
||||
return errno;
|
||||
#endif // _WIN32
|
||||
@@ -132,14 +207,16 @@ class Socket {
|
||||
*/
|
||||
inline static void Startup() {
|
||||
#ifdef _WIN32
|
||||
#if !IS_MINGW()
|
||||
WSADATA wsa_data;
|
||||
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
|
||||
Socket::Error("Startup");
|
||||
Socket::Error("Startup");
|
||||
}
|
||||
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
|
||||
WSACleanup();
|
||||
utils::Error("Could not find a usable version of Winsock.dll\n");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
#endif // _WIN32
|
||||
}
|
||||
/*!
|
||||
@@ -147,7 +224,9 @@ class Socket {
|
||||
*/
|
||||
inline static void Finalize() {
|
||||
#ifdef _WIN32
|
||||
#if !IS_MINGW()
|
||||
WSACleanup();
|
||||
#endif // !IS_MINGW()
|
||||
#endif // _WIN32
|
||||
}
|
||||
/*!
|
||||
@@ -157,10 +236,12 @@ class Socket {
|
||||
*/
|
||||
inline void SetNonBlock(bool non_block) {
|
||||
#ifdef _WIN32
|
||||
#if !IS_MINGW()
|
||||
u_long mode = non_block ? 1 : 0;
|
||||
if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
|
||||
Socket::Error("SetNonBlock");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
#else
|
||||
int flag = fcntl(sockfd, F_GETFL, 0);
|
||||
if (flag == -1) {
|
||||
@@ -181,10 +262,12 @@ class Socket {
|
||||
* \param addr
|
||||
*/
|
||||
inline void Bind(const SockAddr &addr) {
|
||||
#if !IS_MINGW()
|
||||
if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
|
||||
sizeof(addr.addr)) == -1) {
|
||||
Socket::Error("Bind");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief try bind the socket to host, from start_port to end_port
|
||||
@@ -194,6 +277,7 @@ class Socket {
|
||||
*/
|
||||
inline int TryBindHost(int start_port, int end_port) {
|
||||
// TODO(tqchen) add prefix check
|
||||
#if !IS_MINGW()
|
||||
for (int port = start_port; port < end_port; ++port) {
|
||||
SockAddr addr("0.0.0.0", port);
|
||||
if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
|
||||
@@ -210,17 +294,22 @@ class Socket {
|
||||
}
|
||||
#endif // defined(_WIN32)
|
||||
}
|
||||
|
||||
#endif // !IS_MINGW()
|
||||
return -1;
|
||||
}
|
||||
/*! \brief get last error code if any */
|
||||
inline int GetSockError() const {
|
||||
int error = 0;
|
||||
socklen_t len = sizeof(error);
|
||||
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
|
||||
reinterpret_cast<char*>(&error), &len) != 0) {
|
||||
#if !IS_MINGW()
|
||||
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
|
||||
reinterpret_cast<char *>(&error), &len) != 0) {
|
||||
Error("GetSockError");
|
||||
}
|
||||
#else
|
||||
// undefined reference to `_imp__getsockopt@20'
|
||||
MingWError();
|
||||
#endif // !IS_MINGW()
|
||||
return error;
|
||||
}
|
||||
/*! \brief check if anything bad happens */
|
||||
@@ -238,7 +327,9 @@ class Socket {
|
||||
inline void Close() {
|
||||
if (sockfd != kInvalidSocket) {
|
||||
#ifdef _WIN32
|
||||
#if !IS_MINGW()
|
||||
closesocket(sockfd);
|
||||
#endif // !IS_MINGW()
|
||||
#else
|
||||
close(sockfd);
|
||||
#endif
|
||||
@@ -277,50 +368,64 @@ class TCPSocket : public Socket{
|
||||
* \param keepalive whether to set the keep alive option on
|
||||
*/
|
||||
void SetKeepAlive(bool keepalive) {
|
||||
#if !IS_MINGW()
|
||||
int opt = static_cast<int>(keepalive);
|
||||
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
|
||||
reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
|
||||
Socket::Error("SetKeepAlive");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
inline void SetLinger(int timeout = 0) {
|
||||
#if !IS_MINGW()
|
||||
struct linger sl;
|
||||
sl.l_onoff = 1; /* non-zero value enables linger option in kernel */
|
||||
sl.l_linger = timeout; /* timeout interval in seconds */
|
||||
if (setsockopt(sockfd, SOL_SOCKET, SO_LINGER, reinterpret_cast<char*>(&sl), sizeof(sl)) == -1) {
|
||||
Socket::Error("SO_LINGER");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief create the socket, call this before using socket
|
||||
* \param af domain
|
||||
*/
|
||||
inline void Create(int af = PF_INET) {
|
||||
#if !IS_MINGW()
|
||||
sockfd = socket(PF_INET, SOCK_STREAM, 0);
|
||||
if (sockfd == kInvalidSocket) {
|
||||
Socket::Error("Create");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief perform listen of the socket
|
||||
* \param backlog backlog parameter
|
||||
*/
|
||||
inline void Listen(int backlog = 16) {
|
||||
#if !IS_MINGW()
|
||||
listen(sockfd, backlog);
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*! \brief get a new connection */
|
||||
TCPSocket Accept() {
|
||||
#if !IS_MINGW()
|
||||
SOCKET newfd = accept(sockfd, nullptr, nullptr);
|
||||
if (newfd == kInvalidSocket) {
|
||||
Socket::Error("Accept");
|
||||
}
|
||||
return TCPSocket(newfd);
|
||||
#else
|
||||
return TCPSocket();
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief decide whether the socket is at OOB mark
|
||||
* \return 1 if at mark, 0 if not, -1 if an error occured
|
||||
*/
|
||||
inline int AtMark() const {
|
||||
#if !IS_MINGW()
|
||||
|
||||
#ifdef _WIN32
|
||||
unsigned long atmark; // NOLINT(*)
|
||||
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
|
||||
@@ -328,7 +433,12 @@ class TCPSocket : public Socket{
|
||||
int atmark;
|
||||
if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
|
||||
#endif // _WIN32
|
||||
|
||||
return static_cast<int>(atmark);
|
||||
|
||||
#else
|
||||
return -1;
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief connect to an address
|
||||
@@ -336,8 +446,12 @@ class TCPSocket : public Socket{
|
||||
* \return whether connect is successful
|
||||
*/
|
||||
inline bool Connect(const SockAddr &addr) {
|
||||
#if !IS_MINGW()
|
||||
return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
|
||||
sizeof(addr.addr)) == 0;
|
||||
#else
|
||||
return false;
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief send data using the socket
|
||||
@@ -349,7 +463,11 @@ class TCPSocket : public Socket{
|
||||
*/
|
||||
inline ssize_t Send(const void *buf_, size_t len, int flag = 0) {
|
||||
const char *buf = reinterpret_cast<const char*>(buf_);
|
||||
#if !IS_MINGW()
|
||||
return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
|
||||
#else
|
||||
return 0;
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief receive data using the socket
|
||||
@@ -361,7 +479,11 @@ class TCPSocket : public Socket{
|
||||
*/
|
||||
inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
|
||||
char *buf = reinterpret_cast<char*>(buf_);
|
||||
#if !IS_MINGW()
|
||||
return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
|
||||
#else
|
||||
return 0;
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief peform block write that will attempt to send all data out
|
||||
@@ -373,6 +495,7 @@ class TCPSocket : public Socket{
|
||||
inline size_t SendAll(const void *buf_, size_t len) {
|
||||
const char *buf = reinterpret_cast<const char*>(buf_);
|
||||
size_t ndone = 0;
|
||||
#if !IS_MINGW()
|
||||
while (ndone < len) {
|
||||
ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
|
||||
if (ret == -1) {
|
||||
@@ -382,6 +505,7 @@ class TCPSocket : public Socket{
|
||||
buf += ret;
|
||||
ndone += ret;
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
return ndone;
|
||||
}
|
||||
/*!
|
||||
@@ -394,6 +518,7 @@ class TCPSocket : public Socket{
|
||||
inline size_t RecvAll(void *buf_, size_t len) {
|
||||
char *buf = reinterpret_cast<char*>(buf_);
|
||||
size_t ndone = 0;
|
||||
#if !IS_MINGW()
|
||||
while (ndone < len) {
|
||||
ssize_t ret = recv(sockfd, buf,
|
||||
static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
|
||||
@@ -405,6 +530,7 @@ class TCPSocket : public Socket{
|
||||
buf += ret;
|
||||
ndone += ret;
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
return ndone;
|
||||
}
|
||||
/*!
|
||||
@@ -500,7 +626,7 @@ struct PollHelper {
|
||||
pollfd pfd;
|
||||
pfd.fd = fd;
|
||||
pfd.events = POLLPRI;
|
||||
return poll(&pfd, 1, timeout);
|
||||
return PollImpl(&pfd, 1, timeout);
|
||||
}
|
||||
|
||||
/*!
|
||||
@@ -514,7 +640,7 @@ struct PollHelper {
|
||||
for (auto kv : fds) {
|
||||
fdset.push_back(kv.second);
|
||||
}
|
||||
int ret = poll(fdset.data(), fdset.size(), timeout);
|
||||
int ret = PollImpl(fdset.data(), fdset.size(), timeout);
|
||||
if (ret == -1) {
|
||||
Socket::Error("Poll");
|
||||
} else {
|
||||
@@ -533,4 +659,11 @@ struct PollHelper {
|
||||
};
|
||||
} // namespace utils
|
||||
} // namespace rabit
|
||||
|
||||
#if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
|
||||
#undef POLLIN
|
||||
#undef POLLPRI
|
||||
#undef POLLOUT
|
||||
#endif // IS_MINGW()
|
||||
|
||||
#endif // RABIT_INTERNAL_SOCKET_H_
|
||||
|
||||
@@ -15,10 +15,7 @@
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#include "dmlc/io.h"
|
||||
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
#include <cstdarg>
|
||||
#endif // RABIT_STRICT_CXX98_
|
||||
|
||||
#if !defined(__GNUC__) || defined(__FreeBSD__)
|
||||
#define fopen64 std::fopen
|
||||
@@ -71,7 +68,6 @@ inline bool StringToBool(const char* s) {
|
||||
return CompareStringsCaseInsensitive(s, "true") == 0 || atoi(s) != 0;
|
||||
}
|
||||
|
||||
#ifndef RABIT_CUSTOMIZE_MSG_
|
||||
/*!
|
||||
* \brief handling of Assert error, caused by inappropriate input
|
||||
* \param msg error message
|
||||
@@ -89,6 +85,7 @@ inline void HandleCheckError(const char *msg) {
|
||||
fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
|
||||
throw dmlc::Error(msg);
|
||||
}
|
||||
|
||||
inline void HandlePrint(const char *msg) {
|
||||
printf("%s", msg);
|
||||
}
|
||||
@@ -102,22 +99,7 @@ inline void HandleLogInfo(const char *fmt, ...) {
|
||||
fprintf(stdout, "%s", msg.c_str());
|
||||
fflush(stdout);
|
||||
}
|
||||
#else
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
// include declarations, some one must implement this
|
||||
void HandleAssertError(const char *msg);
|
||||
void HandleCheckError(const char *msg);
|
||||
void HandlePrint(const char *msg);
|
||||
#endif // RABIT_STRICT_CXX98_
|
||||
#endif // RABIT_CUSTOMIZE_MSG_
|
||||
#ifdef RABIT_STRICT_CXX98_
|
||||
// these function pointers are to be assigned
|
||||
extern "C" void (*Printf)(const char *fmt, ...);
|
||||
extern "C" int (*SPrintf)(char *buf, size_t size, const char *fmt, ...);
|
||||
extern "C" void (*Assert)(int exp, const char *fmt, ...);
|
||||
extern "C" void (*Check)(int exp, const char *fmt, ...);
|
||||
extern "C" void (*Error)(const char *fmt, ...);
|
||||
#else
|
||||
|
||||
/*! \brief printf, prints messages to the console */
|
||||
inline void Printf(const char *fmt, ...) {
|
||||
std::string msg(kPrintBuffer, '\0');
|
||||
@@ -127,6 +109,7 @@ inline void Printf(const char *fmt, ...) {
|
||||
va_end(args);
|
||||
HandlePrint(msg.c_str());
|
||||
}
|
||||
|
||||
/*! \brief portable version of snprintf */
|
||||
inline int SPrintf(char *buf, size_t size, const char *fmt, ...) {
|
||||
va_list args;
|
||||
@@ -171,7 +154,6 @@ inline void Error(const char *fmt, ...) {
|
||||
HandleCheckError(msg.c_str());
|
||||
}
|
||||
}
|
||||
#endif // RABIT_STRICT_CXX98_
|
||||
|
||||
/*! \brief replace fopen, report error when the file open fails */
|
||||
inline std::FILE *FopenCheck(const char *fname, const char *flag) {
|
||||
@@ -180,6 +162,19 @@ inline std::FILE *FopenCheck(const char *fname, const char *flag) {
|
||||
return fp;
|
||||
}
|
||||
} // namespace utils
|
||||
|
||||
// Can not use std::min on Windows with msvc due to:
|
||||
// error C2589: '(': illegal token on right side of '::'
|
||||
template <typename T>
|
||||
auto Min(T const& l, T const& r) {
|
||||
return l < r ? l : r;
|
||||
}
|
||||
// same with Min
|
||||
template <typename T>
|
||||
auto Max(T const& l, T const& r) {
|
||||
return l > r ? l : r;
|
||||
}
|
||||
|
||||
// easy utils that can be directly accessed in xgboost
|
||||
/*! \brief get the beginning address of a vector */
|
||||
template<typename T>
|
||||
|
||||
Reference in New Issue
Block a user