Enable building rabit on Windows (#6105)

This commit is contained in:
Jiaming Yuan
2020-09-11 11:54:46 +08:00
committed by GitHub
parent 08bdb2efc8
commit c92d751ad1
17 changed files with 215 additions and 352 deletions

View File

@@ -9,10 +9,13 @@
#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
#ifdef _MSC_VER
#pragma comment(lib, "Ws2_32.lib")
#endif // _MSC_VER
#else
#include <fcntl.h>
#include <netdb.h>
#include <cerrno>
@@ -21,31 +24,92 @@
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#endif // defined(_WIN32)
#include <string>
#include <cstring>
#include <vector>
#include <unordered_map>
#include "utils.h"
#if defined(_WIN32) || defined(__MINGW32__)
#if defined(_WIN32) && !defined(__MINGW32__)
typedef int ssize_t;
#endif // defined(_WIN32) || defined(__MINGW32__)
#if defined(_WIN32)
typedef int sock_size_t;
using sock_size_t = int;
static inline int poll(struct pollfd *pfd, int nfds,
int timeout) { return WSAPoll ( pfd, nfds, timeout ); }
#else
#include <sys/poll.h>
using SOCKET = int;
using sock_size_t = size_t; // NOLINT
const int kInvalidSocket = -1;
#endif // defined(_WIN32)
#define IS_MINGW() defined(__MINGW32__)
#if IS_MINGW()
inline void MingWError() {
throw dmlc::Error("Distributed training on mingw is not supported.");
}
#endif // IS_MINGW()
#if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
/*
* On later mingw versions poll should be supported (with bugs). See:
* https://stackoverflow.com/a/60623080
*
* But right now the mingw distributed with R 3.6 doesn't support it.
* So we just give a warning and provide dummy implementation to get
* compilation passed. Otherwise we will have to provide a stub for
* RABIT.
*
* Even on mingw version that has these structures and flags defined,
* functions like `send` and `listen` might have unresolved linkage to
* their implementation. So supporting mingw is quite difficult at
* the time of writing.
*/
#pragma message("Distributed training on mingw is not supported.")
typedef struct pollfd {
SOCKET fd;
short events;
short revents;
} WSAPOLLFD, *PWSAPOLLFD, *LPWSAPOLLFD;
// POLLRDNORM | POLLRDBAND
#define POLLIN (0x0100 | 0x0200)
#define POLLPRI 0x0400
// POLLWRNORM
#define POLLOUT 0x0010
inline const char *inet_ntop(int, const void *, char *, size_t) {
MingWError();
return nullptr;
}
#endif // IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
namespace rabit {
namespace utils {
static constexpr int kInvalidSocket = -1;
template <typename PollFD>
int PollImpl(PollFD *pfd, int nfds, int timeout) {
#if defined(_WIN32)
#if IS_MINGW()
MingWError();
return -1;
#else
return WSAPoll(pfd, nfds, timeout);
#endif // IS_MINGW()
#else
return poll(pfd, nfds, timeout);
#endif // IS_MINGW()
}
/*! \brief data structure for network address */
struct SockAddr {
sockaddr_in addr;
@@ -56,7 +120,9 @@ struct SockAddr {
}
inline static std::string GetHostName() {
std::string buf; buf.resize(256);
#if !IS_MINGW()
utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
#endif // IS_MINGW()
return std::string(buf.c_str());
}
/*!
@@ -65,6 +131,7 @@ struct SockAddr {
* \param port the port of address
*/
inline void Set(const char *host, int port) {
#if !IS_MINGW()
addrinfo hints;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_INET;
@@ -76,6 +143,7 @@ struct SockAddr {
memcpy(&addr, res->ai_addr, res->ai_addrlen);
addr.sin_port = htons(port);
freeaddrinfo(res);
#endif // !IS_MINGW()
}
/*! \brief return port of the address*/
inline int Port() const {
@@ -112,7 +180,14 @@ class Socket {
*/
inline static int GetLastError() {
#ifdef _WIN32
#if IS_MINGW()
MingWError();
return -1;
#else
return WSAGetLastError();
#endif // IS_MINGW()
#else
return errno;
#endif // _WIN32
@@ -132,14 +207,16 @@ class Socket {
*/
inline static void Startup() {
#ifdef _WIN32
#if !IS_MINGW()
WSADATA wsa_data;
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
Socket::Error("Startup");
Socket::Error("Startup");
}
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
WSACleanup();
utils::Error("Could not find a usable version of Winsock.dll\n");
}
#endif // !IS_MINGW()
#endif // _WIN32
}
/*!
@@ -147,7 +224,9 @@ class Socket {
*/
inline static void Finalize() {
#ifdef _WIN32
#if !IS_MINGW()
WSACleanup();
#endif // !IS_MINGW()
#endif // _WIN32
}
/*!
@@ -157,10 +236,12 @@ class Socket {
*/
inline void SetNonBlock(bool non_block) {
#ifdef _WIN32
#if !IS_MINGW()
u_long mode = non_block ? 1 : 0;
if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
Socket::Error("SetNonBlock");
}
#endif // !IS_MINGW()
#else
int flag = fcntl(sockfd, F_GETFL, 0);
if (flag == -1) {
@@ -181,10 +262,12 @@ class Socket {
* \param addr
*/
inline void Bind(const SockAddr &addr) {
#if !IS_MINGW()
if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
sizeof(addr.addr)) == -1) {
Socket::Error("Bind");
}
#endif // !IS_MINGW()
}
/*!
* \brief try bind the socket to host, from start_port to end_port
@@ -194,6 +277,7 @@ class Socket {
*/
inline int TryBindHost(int start_port, int end_port) {
// TODO(tqchen) add prefix check
#if !IS_MINGW()
for (int port = start_port; port < end_port; ++port) {
SockAddr addr("0.0.0.0", port);
if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
@@ -210,17 +294,22 @@ class Socket {
}
#endif // defined(_WIN32)
}
#endif // !IS_MINGW()
return -1;
}
/*! \brief get last error code if any */
inline int GetSockError() const {
int error = 0;
socklen_t len = sizeof(error);
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
reinterpret_cast<char*>(&error), &len) != 0) {
#if !IS_MINGW()
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
reinterpret_cast<char *>(&error), &len) != 0) {
Error("GetSockError");
}
#else
// undefined reference to `_imp__getsockopt@20'
MingWError();
#endif // !IS_MINGW()
return error;
}
/*! \brief check if anything bad happens */
@@ -238,7 +327,9 @@ class Socket {
inline void Close() {
if (sockfd != kInvalidSocket) {
#ifdef _WIN32
#if !IS_MINGW()
closesocket(sockfd);
#endif // !IS_MINGW()
#else
close(sockfd);
#endif
@@ -277,50 +368,64 @@ class TCPSocket : public Socket{
* \param keepalive whether to set the keep alive option on
*/
void SetKeepAlive(bool keepalive) {
#if !IS_MINGW()
int opt = static_cast<int>(keepalive);
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
Socket::Error("SetKeepAlive");
}
#endif // !IS_MINGW()
}
inline void SetLinger(int timeout = 0) {
#if !IS_MINGW()
struct linger sl;
sl.l_onoff = 1; /* non-zero value enables linger option in kernel */
sl.l_linger = timeout; /* timeout interval in seconds */
if (setsockopt(sockfd, SOL_SOCKET, SO_LINGER, reinterpret_cast<char*>(&sl), sizeof(sl)) == -1) {
Socket::Error("SO_LINGER");
}
#endif // !IS_MINGW()
}
/*!
* \brief create the socket, call this before using socket
* \param af domain
*/
inline void Create(int af = PF_INET) {
#if !IS_MINGW()
sockfd = socket(PF_INET, SOCK_STREAM, 0);
if (sockfd == kInvalidSocket) {
Socket::Error("Create");
}
#endif // !IS_MINGW()
}
/*!
* \brief perform listen of the socket
* \param backlog backlog parameter
*/
inline void Listen(int backlog = 16) {
#if !IS_MINGW()
listen(sockfd, backlog);
#endif // !IS_MINGW()
}
/*! \brief get a new connection */
TCPSocket Accept() {
#if !IS_MINGW()
SOCKET newfd = accept(sockfd, nullptr, nullptr);
if (newfd == kInvalidSocket) {
Socket::Error("Accept");
}
return TCPSocket(newfd);
#else
return TCPSocket();
#endif // !IS_MINGW()
}
/*!
* \brief decide whether the socket is at OOB mark
* \return 1 if at mark, 0 if not, -1 if an error occured
*/
inline int AtMark() const {
#if !IS_MINGW()
#ifdef _WIN32
unsigned long atmark; // NOLINT(*)
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
@@ -328,7 +433,12 @@ class TCPSocket : public Socket{
int atmark;
if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
#endif // _WIN32
return static_cast<int>(atmark);
#else
return -1;
#endif // !IS_MINGW()
}
/*!
* \brief connect to an address
@@ -336,8 +446,12 @@ class TCPSocket : public Socket{
* \return whether connect is successful
*/
inline bool Connect(const SockAddr &addr) {
#if !IS_MINGW()
return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
sizeof(addr.addr)) == 0;
#else
return false;
#endif // !IS_MINGW()
}
/*!
* \brief send data using the socket
@@ -349,7 +463,11 @@ class TCPSocket : public Socket{
*/
inline ssize_t Send(const void *buf_, size_t len, int flag = 0) {
const char *buf = reinterpret_cast<const char*>(buf_);
#if !IS_MINGW()
return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
#else
return 0;
#endif // !IS_MINGW()
}
/*!
* \brief receive data using the socket
@@ -361,7 +479,11 @@ class TCPSocket : public Socket{
*/
inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
char *buf = reinterpret_cast<char*>(buf_);
#if !IS_MINGW()
return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
#else
return 0;
#endif // !IS_MINGW()
}
/*!
* \brief peform block write that will attempt to send all data out
@@ -373,6 +495,7 @@ class TCPSocket : public Socket{
inline size_t SendAll(const void *buf_, size_t len) {
const char *buf = reinterpret_cast<const char*>(buf_);
size_t ndone = 0;
#if !IS_MINGW()
while (ndone < len) {
ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
if (ret == -1) {
@@ -382,6 +505,7 @@ class TCPSocket : public Socket{
buf += ret;
ndone += ret;
}
#endif // !IS_MINGW()
return ndone;
}
/*!
@@ -394,6 +518,7 @@ class TCPSocket : public Socket{
inline size_t RecvAll(void *buf_, size_t len) {
char *buf = reinterpret_cast<char*>(buf_);
size_t ndone = 0;
#if !IS_MINGW()
while (ndone < len) {
ssize_t ret = recv(sockfd, buf,
static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
@@ -405,6 +530,7 @@ class TCPSocket : public Socket{
buf += ret;
ndone += ret;
}
#endif // !IS_MINGW()
return ndone;
}
/*!
@@ -500,7 +626,7 @@ struct PollHelper {
pollfd pfd;
pfd.fd = fd;
pfd.events = POLLPRI;
return poll(&pfd, 1, timeout);
return PollImpl(&pfd, 1, timeout);
}
/*!
@@ -514,7 +640,7 @@ struct PollHelper {
for (auto kv : fds) {
fdset.push_back(kv.second);
}
int ret = poll(fdset.data(), fdset.size(), timeout);
int ret = PollImpl(fdset.data(), fdset.size(), timeout);
if (ret == -1) {
Socket::Error("Poll");
} else {
@@ -533,4 +659,11 @@ struct PollHelper {
};
} // namespace utils
} // namespace rabit
#if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
#undef POLLIN
#undef POLLPRI
#undef POLLOUT
#endif // IS_MINGW()
#endif // RABIT_INTERNAL_SOCKET_H_

View File

@@ -15,10 +15,7 @@
#include <stdexcept>
#include <vector>
#include "dmlc/io.h"
#ifndef RABIT_STRICT_CXX98_
#include <cstdarg>
#endif // RABIT_STRICT_CXX98_
#if !defined(__GNUC__) || defined(__FreeBSD__)
#define fopen64 std::fopen
@@ -71,7 +68,6 @@ inline bool StringToBool(const char* s) {
return CompareStringsCaseInsensitive(s, "true") == 0 || atoi(s) != 0;
}
#ifndef RABIT_CUSTOMIZE_MSG_
/*!
* \brief handling of Assert error, caused by inappropriate input
* \param msg error message
@@ -89,6 +85,7 @@ inline void HandleCheckError(const char *msg) {
fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
throw dmlc::Error(msg);
}
inline void HandlePrint(const char *msg) {
printf("%s", msg);
}
@@ -102,22 +99,7 @@ inline void HandleLogInfo(const char *fmt, ...) {
fprintf(stdout, "%s", msg.c_str());
fflush(stdout);
}
#else
#ifndef RABIT_STRICT_CXX98_
// include declarations, some one must implement this
void HandleAssertError(const char *msg);
void HandleCheckError(const char *msg);
void HandlePrint(const char *msg);
#endif // RABIT_STRICT_CXX98_
#endif // RABIT_CUSTOMIZE_MSG_
#ifdef RABIT_STRICT_CXX98_
// these function pointers are to be assigned
extern "C" void (*Printf)(const char *fmt, ...);
extern "C" int (*SPrintf)(char *buf, size_t size, const char *fmt, ...);
extern "C" void (*Assert)(int exp, const char *fmt, ...);
extern "C" void (*Check)(int exp, const char *fmt, ...);
extern "C" void (*Error)(const char *fmt, ...);
#else
/*! \brief printf, prints messages to the console */
inline void Printf(const char *fmt, ...) {
std::string msg(kPrintBuffer, '\0');
@@ -127,6 +109,7 @@ inline void Printf(const char *fmt, ...) {
va_end(args);
HandlePrint(msg.c_str());
}
/*! \brief portable version of snprintf */
inline int SPrintf(char *buf, size_t size, const char *fmt, ...) {
va_list args;
@@ -171,7 +154,6 @@ inline void Error(const char *fmt, ...) {
HandleCheckError(msg.c_str());
}
}
#endif // RABIT_STRICT_CXX98_
/*! \brief replace fopen, report error when the file open fails */
inline std::FILE *FopenCheck(const char *fname, const char *flag) {
@@ -180,6 +162,19 @@ inline std::FILE *FopenCheck(const char *fname, const char *flag) {
return fp;
}
} // namespace utils
// Can not use std::min on Windows with msvc due to:
// error C2589: '(': illegal token on right side of '::'
template <typename T>
auto Min(T const& l, T const& r) {
return l < r ? l : r;
}
// same with Min
template <typename T>
auto Max(T const& l, T const& r) {
return l > r ? l : r;
}
// easy utils that can be directly accessed in xgboost
/*! \brief get the beginning address of a vector */
template<typename T>