Enable building rabit on Windows (#6105)

This commit is contained in:
Jiaming Yuan 2020-09-11 11:54:46 +08:00 committed by GitHub
parent 08bdb2efc8
commit c92d751ad1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 215 additions and 352 deletions

View File

@ -203,9 +203,8 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
config: config:
- {os: windows-2016, r: 'release', compiler: 'msvc', build: 'autotools'}
- {os: windows-2016, r: 'release', compiler: 'msvc', build: 'cmake'}
- {os: windows-2016, r: 'release', compiler: 'mingw', build: 'autotools'} - {os: windows-2016, r: 'release', compiler: 'mingw', build: 'autotools'}
- {os: windows-2016, r: 'release', compiler: 'msvc', build: 'cmake'}
- {os: windows-2016, r: 'release', compiler: 'mingw', build: 'cmake'} - {os: windows-2016, r: 'release', compiler: 'mingw', build: 'cmake'}
env: env:
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true R_REMOTES_NO_ERRORS_FROM_WARNINGS: true

View File

@ -177,7 +177,7 @@ else()
-D_CRT_SECURE_NO_WARNINGS -D_CRT_SECURE_NO_DEPRECATE) -D_CRT_SECURE_NO_WARNINGS -D_CRT_SECURE_NO_DEPRECATE)
endif (MSVC) endif (MSVC)
endif(RABIT_MOCK) endif(RABIT_MOCK)
foreach(lib rabit rabit_base rabit_empty rabit_mock rabit_mock_static) foreach(lib rabit rabit_base rabit_mock rabit_mock_static)
# Explicitly link dmlc to rabit, so that configured header (build_config.h) # Explicitly link dmlc to rabit, so that configured header (build_config.h)
# from dmlc is correctly applied to rabit. # from dmlc is correctly applied to rabit.
if (TARGET ${lib}) if (TARGET ${lib})

View File

@ -52,4 +52,3 @@ AC_SUBST(ENDIAN_FLAG)
AC_SUBST(BACKTRACE_LIB) AC_SUBST(BACKTRACE_LIB)
AC_CONFIG_FILES([src/Makevars]) AC_CONFIG_FILES([src/Makevars])
AC_OUTPUT AC_OUTPUT

View File

@ -8,7 +8,7 @@ CXX_STD = CXX14
XGB_RFLAGS = -DXGBOOST_STRICT_R_MODE=1 -DDMLC_LOG_BEFORE_THROW=0\ XGB_RFLAGS = -DXGBOOST_STRICT_R_MODE=1 -DDMLC_LOG_BEFORE_THROW=0\
-DDMLC_ENABLE_STD_THREAD=$(ENABLE_STD_THREAD) -DDMLC_DISABLE_STDIN=1\ -DDMLC_ENABLE_STD_THREAD=$(ENABLE_STD_THREAD) -DDMLC_DISABLE_STDIN=1\
-DDMLC_LOG_CUSTOMIZE=1 -DXGBOOST_CUSTOMIZE_LOGGER=1\ -DDMLC_LOG_CUSTOMIZE=1 -DXGBOOST_CUSTOMIZE_LOGGER=1\
-DRABIT_CUSTOMIZE_MSG_ -DRABIT_STRICT_CXX98_ -DRABIT_CUSTOMIZE_MSG_
# disable the use of thread_local for 32 bit windows: # disable the use of thread_local for 32 bit windows:
ifeq ($(R_OSTYPE)$(WIN),windows) ifeq ($(R_OSTYPE)$(WIN),windows)
@ -21,4 +21,5 @@ PKG_CXXFLAGS= @OPENMP_CXXFLAGS@ @ENDIAN_FLAG@ -pthread
PKG_LIBS = @OPENMP_CXXFLAGS@ @OPENMP_LIB@ @ENDIAN_FLAG@ @BACKTRACE_LIB@ -pthread PKG_LIBS = @OPENMP_CXXFLAGS@ @OPENMP_LIB@ @ENDIAN_FLAG@ @BACKTRACE_LIB@ -pthread
OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o ./init.o \ OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o ./init.o \
$(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o \ $(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o \
$(PKGROOT)/rabit/src/engine_empty.o $(PKGROOT)/rabit/src/c_api.o $(PKGROOT)/rabit/src/engine.o $(PKGROOT)/rabit/src/c_api.o \
$(PKGROOT)/rabit/src/allreduce_base.o $(PKGROOT)/rabit/src/allreduce_robust.o

View File

@ -20,7 +20,7 @@ CXX_STD = CXX14
XGB_RFLAGS = -DXGBOOST_STRICT_R_MODE=1 -DDMLC_LOG_BEFORE_THROW=0\ XGB_RFLAGS = -DXGBOOST_STRICT_R_MODE=1 -DDMLC_LOG_BEFORE_THROW=0\
-DDMLC_ENABLE_STD_THREAD=$(ENABLE_STD_THREAD) -DDMLC_DISABLE_STDIN=1\ -DDMLC_ENABLE_STD_THREAD=$(ENABLE_STD_THREAD) -DDMLC_DISABLE_STDIN=1\
-DDMLC_LOG_CUSTOMIZE=1 -DXGBOOST_CUSTOMIZE_LOGGER=1\ -DDMLC_LOG_CUSTOMIZE=1 -DXGBOOST_CUSTOMIZE_LOGGER=1\
-DRABIT_CUSTOMIZE_MSG_ -DRABIT_STRICT_CXX98_ -DRABIT_CUSTOMIZE_MSG_
# disable the use of thread_local for 32 bit windows: # disable the use of thread_local for 32 bit windows:
ifeq ($(R_OSTYPE)$(WIN),windows) ifeq ($(R_OSTYPE)$(WIN),windows)
@ -33,6 +33,7 @@ PKG_CXXFLAGS= $(SHLIB_OPENMP_CXXFLAGS) $(SHLIB_PTHREAD_FLAGS)
PKG_LIBS = $(SHLIB_OPENMP_CXXFLAGS) $(SHLIB_PTHREAD_FLAGS) PKG_LIBS = $(SHLIB_OPENMP_CXXFLAGS) $(SHLIB_PTHREAD_FLAGS)
OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o ./init.o \ OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o ./init.o \
$(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o \ $(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o \
$(PKGROOT)/rabit/src/engine_empty.o $(PKGROOT)/rabit/src/c_api.o $(PKGROOT)/rabit/src/engine.o $(PKGROOT)/rabit/src/c_api.o \
$(PKGROOT)/rabit/src/allreduce_base.o $(PKGROOT)/rabit/src/allreduce_robust.o
$(OBJECTS) : xgblib $(OBJECTS) : xgblib

View File

@ -13,23 +13,6 @@ void CustomLogMessage::Log(const std::string& msg) {
} }
} // namespace dmlc } // namespace dmlc
// implements rabit error handling.
extern "C" {
void XGBoostAssert_R(int exp, const char *fmt, ...);
void XGBoostCheck_R(int exp, const char *fmt, ...);
}
namespace rabit {
namespace utils {
extern "C" {
void (*Printf)(const char *fmt, ...) = Rprintf;
void (*Assert)(int exp, const char *fmt, ...) = XGBoostAssert_R;
void (*Check)(int exp, const char *fmt, ...) = XGBoostCheck_R;
void (*Error)(const char *fmt, ...) = error;
}
}
}
namespace xgboost { namespace xgboost {
ConsoleLogger::~ConsoleLogger() { ConsoleLogger::~ConsoleLogger() {
if (cur_verbosity_ == LogVerbosity::kIgnore || if (cur_verbosity_ == LogVerbosity::kIgnore ||

View File

@ -12,5 +12,3 @@
#include "../dmlc-core/src/data.cc" #include "../dmlc-core/src/data.cc"
#include "../dmlc-core/src/io.cc" #include "../dmlc-core/src/io.cc"
#include "../dmlc-core/src/recordio.cc" #include "../dmlc-core/src/recordio.cc"

View File

@ -1,13 +1,5 @@
cmake_minimum_required(VERSION 3.3) cmake_minimum_required(VERSION 3.3)
if(R_LIB OR MINGW OR WIN32)
add_library(rabit src/engine_empty.cc src/c_api.cc)
set(rabit_libs rabit)
set_target_properties(rabit
PROPERTIES CXX_STANDARD 14
CXX_STANDARD_REQUIRED ON
POSITION_INDEPENDENT_CODE ON)
else()
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
add_library(rabit src/allreduce_base.cc src/allreduce_robust.cc src/engine.cc src/c_api.cc) add_library(rabit src/allreduce_base.cc src/allreduce_robust.cc src/engine.cc src/c_api.cc)
@ -22,7 +14,6 @@ else()
PROPERTIES CXX_STANDARD 14 PROPERTIES CXX_STANDARD 14
CXX_STANDARD_REQUIRED ON CXX_STANDARD_REQUIRED ON
POSITION_INDEPENDENT_CODE ON) POSITION_INDEPENDENT_CODE ON)
endif(R_LIB OR MINGW OR WIN32)
if(RABIT_BUILD_MPI) if(RABIT_BUILD_MPI)
find_package(MPI REQUIRED) find_package(MPI REQUIRED)

View File

@ -1,104 +0,0 @@
OS := $(shell uname)
RABIT_BUILD_DMLC = 0
export WARNFLAGS= -Wall -Wextra -Wno-unused-parameter -Wno-unknown-pragmas -std=c++11
export CFLAGS = -O3 $(WARNFLAGS)
export LDFLAGS =-Llib
#download mpi
#echo $(shell scripts/mpi.sh)
MPICXX=./mpich/bin/mpicxx
export CXX = g++
#----------------------------
# Settings for power and arm arch
#----------------------------
ARCH := $(shell uname -a)
ifneq (,$(filter $(ARCH), armv6l armv7l powerpc64le ppc64le aarch64))
CFLAGS += -march=native
else
CFLAGS += -msse2
endif
ifndef WITH_FPIC
WITH_FPIC = 1
endif
ifeq ($(WITH_FPIC), 1)
CFLAGS += -fPIC
endif
ifndef LINT_LANG
LINT_LANG="all"
endif
ifeq ($(RABIT_BUILD_DMLC),1)
DMLC=dmlc-core
else
DMLC=../dmlc-core
endif
CFLAGS += -I $(DMLC)/include -I include/
# build path
BPATH=.
# objectives that makes up rabit library
MPIOBJ= $(BPATH)/engine_mpi.o
OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/engine_empty.o $(BPATH)/engine_mock.o\
$(BPATH)/c_api.o $(BPATH)/engine_base.o
SLIB= lib/librabit.so lib/librabit_mock.so lib/librabit_base.so
ALIB= lib/librabit.a lib/librabit_empty.a lib/librabit_mock.a lib/librabit_base.a
MPISLIB= lib/librabit_mpi.so
MPIALIB= lib/librabit_mpi.a
HEADERS=src/*.h include/rabit/*.h include/rabit/internal/*.h
.PHONY: clean all install mpi python lint doc doxygen
all: lib/librabit.a lib/librabit_mock.a lib/librabit.so lib/librabit_base.a lib/librabit_mock.so
mpi: lib/librabit_mpi.a lib/librabit_mpi.so
$(BPATH)/allreduce_base.o: src/allreduce_base.cc $(HEADERS)
$(BPATH)/engine.o: src/engine.cc $(HEADERS)
$(BPATH)/allreduce_robust.o: src/allreduce_robust.cc $(HEADERS)
$(BPATH)/engine_mpi.o: src/engine_mpi.cc $(HEADERS)
$(BPATH)/engine_empty.o: src/engine_empty.cc $(HEADERS)
$(BPATH)/engine_mock.o: src/engine_mock.cc $(HEADERS)
$(BPATH)/engine_base.o: src/engine_base.cc $(HEADERS)
$(BPATH)/c_api.o: src/c_api.cc $(HEADERS)
lib/librabit.a lib/librabit.so: $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/c_api.o
lib/librabit_base.a lib/librabit_base.so: $(BPATH)/allreduce_base.o $(BPATH)/engine_base.o $(BPATH)/c_api.o
lib/librabit_mock.a lib/librabit_mock.so: $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine_mock.o $(BPATH)/c_api.o
lib/librabit_empty.a: $(BPATH)/engine_empty.o $(BPATH)/c_api.o
lib/librabit_mpi.a lib/librabit_mpi.so: $(MPIOBJ)
$(OBJ) :
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
$(ALIB):
ar cr $@ $+
$(SLIB) :
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS)
$(MPIOBJ) :
$(MPICXX) -c $(CFLAGS) -I./mpich/include -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
$(MPIALIB):
ar cr $@ $+
$(MPISLIB) :
$(MPICXX) $(CFLAGS) -I./mpich/include -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) \
$(LDFLAGS) -L./mpich/lib -Wl,-rpath,./mpich/lib -lmpi
lint:
$(DMLC)/scripts/lint.py rabit $(LINT_LANG) src include
doc doxygen:
cd include; doxygen ../doc/Doxyfile; cd -
clean:
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) $(SLIB) *~ src/*~ include/*~ include/*/*~

View File

@ -9,10 +9,13 @@
#if defined(_WIN32) #if defined(_WIN32)
#include <winsock2.h> #include <winsock2.h>
#include <ws2tcpip.h> #include <ws2tcpip.h>
#ifdef _MSC_VER #ifdef _MSC_VER
#pragma comment(lib, "Ws2_32.lib") #pragma comment(lib, "Ws2_32.lib")
#endif // _MSC_VER #endif // _MSC_VER
#else #else
#include <fcntl.h> #include <fcntl.h>
#include <netdb.h> #include <netdb.h>
#include <cerrno> #include <cerrno>
@ -21,31 +24,92 @@
#include <netinet/in.h> #include <netinet/in.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/ioctl.h> #include <sys/ioctl.h>
#endif // defined(_WIN32) #endif // defined(_WIN32)
#include <string> #include <string>
#include <cstring> #include <cstring>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include "utils.h" #include "utils.h"
#if defined(_WIN32) || defined(__MINGW32__) #if defined(_WIN32) && !defined(__MINGW32__)
typedef int ssize_t; typedef int ssize_t;
#endif // defined(_WIN32) || defined(__MINGW32__) #endif // defined(_WIN32) || defined(__MINGW32__)
#if defined(_WIN32) #if defined(_WIN32)
typedef int sock_size_t; using sock_size_t = int;
static inline int poll(struct pollfd *pfd, int nfds,
int timeout) { return WSAPoll ( pfd, nfds, timeout ); }
#else #else
#include <sys/poll.h> #include <sys/poll.h>
using SOCKET = int; using SOCKET = int;
using sock_size_t = size_t; // NOLINT using sock_size_t = size_t; // NOLINT
const int kInvalidSocket = -1;
#endif // defined(_WIN32) #endif // defined(_WIN32)
#define IS_MINGW() defined(__MINGW32__)
#if IS_MINGW()
inline void MingWError() {
throw dmlc::Error("Distributed training on mingw is not supported.");
}
#endif // IS_MINGW()
#if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
/*
* On later mingw versions poll should be supported (with bugs). See:
* https://stackoverflow.com/a/60623080
*
* But right now the mingw distributed with R 3.6 doesn't support it.
* So we just give a warning and provide dummy implementation to get
* compilation passed. Otherwise we will have to provide a stub for
* RABIT.
*
* Even on mingw version that has these structures and flags defined,
* functions like `send` and `listen` might have unresolved linkage to
* their implementation. So supporting mingw is quite difficult at
* the time of writing.
*/
#pragma message("Distributed training on mingw is not supported.")
typedef struct pollfd {
SOCKET fd;
short events;
short revents;
} WSAPOLLFD, *PWSAPOLLFD, *LPWSAPOLLFD;
// POLLRDNORM | POLLRDBAND
#define POLLIN (0x0100 | 0x0200)
#define POLLPRI 0x0400
// POLLWRNORM
#define POLLOUT 0x0010
inline const char *inet_ntop(int, const void *, char *, size_t) {
MingWError();
return nullptr;
}
#endif // IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
namespace rabit { namespace rabit {
namespace utils { namespace utils {
static constexpr int kInvalidSocket = -1;
template <typename PollFD>
int PollImpl(PollFD *pfd, int nfds, int timeout) {
#if defined(_WIN32)
#if IS_MINGW()
MingWError();
return -1;
#else
return WSAPoll(pfd, nfds, timeout);
#endif // IS_MINGW()
#else
return poll(pfd, nfds, timeout);
#endif // IS_MINGW()
}
/*! \brief data structure for network address */ /*! \brief data structure for network address */
struct SockAddr { struct SockAddr {
sockaddr_in addr; sockaddr_in addr;
@ -56,7 +120,9 @@ struct SockAddr {
} }
inline static std::string GetHostName() { inline static std::string GetHostName() {
std::string buf; buf.resize(256); std::string buf; buf.resize(256);
#if !IS_MINGW()
utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name"); utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
#endif // IS_MINGW()
return std::string(buf.c_str()); return std::string(buf.c_str());
} }
/*! /*!
@ -65,6 +131,7 @@ struct SockAddr {
* \param port the port of address * \param port the port of address
*/ */
inline void Set(const char *host, int port) { inline void Set(const char *host, int port) {
#if !IS_MINGW()
addrinfo hints; addrinfo hints;
memset(&hints, 0, sizeof(hints)); memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_INET; hints.ai_family = AF_INET;
@ -76,6 +143,7 @@ struct SockAddr {
memcpy(&addr, res->ai_addr, res->ai_addrlen); memcpy(&addr, res->ai_addr, res->ai_addrlen);
addr.sin_port = htons(port); addr.sin_port = htons(port);
freeaddrinfo(res); freeaddrinfo(res);
#endif // !IS_MINGW()
} }
/*! \brief return port of the address*/ /*! \brief return port of the address*/
inline int Port() const { inline int Port() const {
@ -112,7 +180,14 @@ class Socket {
*/ */
inline static int GetLastError() { inline static int GetLastError() {
#ifdef _WIN32 #ifdef _WIN32
#if IS_MINGW()
MingWError();
return -1;
#else
return WSAGetLastError(); return WSAGetLastError();
#endif // IS_MINGW()
#else #else
return errno; return errno;
#endif // _WIN32 #endif // _WIN32
@ -132,6 +207,7 @@ class Socket {
*/ */
inline static void Startup() { inline static void Startup() {
#ifdef _WIN32 #ifdef _WIN32
#if !IS_MINGW()
WSADATA wsa_data; WSADATA wsa_data;
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) { if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
Socket::Error("Startup"); Socket::Error("Startup");
@ -140,6 +216,7 @@ class Socket {
WSACleanup(); WSACleanup();
utils::Error("Could not find a usable version of Winsock.dll\n"); utils::Error("Could not find a usable version of Winsock.dll\n");
} }
#endif // !IS_MINGW()
#endif // _WIN32 #endif // _WIN32
} }
/*! /*!
@ -147,7 +224,9 @@ class Socket {
*/ */
inline static void Finalize() { inline static void Finalize() {
#ifdef _WIN32 #ifdef _WIN32
#if !IS_MINGW()
WSACleanup(); WSACleanup();
#endif // !IS_MINGW()
#endif // _WIN32 #endif // _WIN32
} }
/*! /*!
@ -157,10 +236,12 @@ class Socket {
*/ */
inline void SetNonBlock(bool non_block) { inline void SetNonBlock(bool non_block) {
#ifdef _WIN32 #ifdef _WIN32
#if !IS_MINGW()
u_long mode = non_block ? 1 : 0; u_long mode = non_block ? 1 : 0;
if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) { if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
Socket::Error("SetNonBlock"); Socket::Error("SetNonBlock");
} }
#endif // !IS_MINGW()
#else #else
int flag = fcntl(sockfd, F_GETFL, 0); int flag = fcntl(sockfd, F_GETFL, 0);
if (flag == -1) { if (flag == -1) {
@ -181,10 +262,12 @@ class Socket {
* \param addr * \param addr
*/ */
inline void Bind(const SockAddr &addr) { inline void Bind(const SockAddr &addr) {
#if !IS_MINGW()
if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr), if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
sizeof(addr.addr)) == -1) { sizeof(addr.addr)) == -1) {
Socket::Error("Bind"); Socket::Error("Bind");
} }
#endif // !IS_MINGW()
} }
/*! /*!
* \brief try bind the socket to host, from start_port to end_port * \brief try bind the socket to host, from start_port to end_port
@ -194,6 +277,7 @@ class Socket {
*/ */
inline int TryBindHost(int start_port, int end_port) { inline int TryBindHost(int start_port, int end_port) {
// TODO(tqchen) add prefix check // TODO(tqchen) add prefix check
#if !IS_MINGW()
for (int port = start_port; port < end_port; ++port) { for (int port = start_port; port < end_port; ++port) {
SockAddr addr("0.0.0.0", port); SockAddr addr("0.0.0.0", port);
if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr), if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
@ -210,17 +294,22 @@ class Socket {
} }
#endif // defined(_WIN32) #endif // defined(_WIN32)
} }
#endif // !IS_MINGW()
return -1; return -1;
} }
/*! \brief get last error code if any */ /*! \brief get last error code if any */
inline int GetSockError() const { inline int GetSockError() const {
int error = 0; int error = 0;
socklen_t len = sizeof(error); socklen_t len = sizeof(error);
#if !IS_MINGW()
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
reinterpret_cast<char *>(&error), &len) != 0) { reinterpret_cast<char *>(&error), &len) != 0) {
Error("GetSockError"); Error("GetSockError");
} }
#else
// undefined reference to `_imp__getsockopt@20'
MingWError();
#endif // !IS_MINGW()
return error; return error;
} }
/*! \brief check if anything bad happens */ /*! \brief check if anything bad happens */
@ -238,7 +327,9 @@ class Socket {
inline void Close() { inline void Close() {
if (sockfd != kInvalidSocket) { if (sockfd != kInvalidSocket) {
#ifdef _WIN32 #ifdef _WIN32
#if !IS_MINGW()
closesocket(sockfd); closesocket(sockfd);
#endif // !IS_MINGW()
#else #else
close(sockfd); close(sockfd);
#endif #endif
@ -277,50 +368,64 @@ class TCPSocket : public Socket{
* \param keepalive whether to set the keep alive option on * \param keepalive whether to set the keep alive option on
*/ */
void SetKeepAlive(bool keepalive) { void SetKeepAlive(bool keepalive) {
#if !IS_MINGW()
int opt = static_cast<int>(keepalive); int opt = static_cast<int>(keepalive);
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) { reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
Socket::Error("SetKeepAlive"); Socket::Error("SetKeepAlive");
} }
#endif // !IS_MINGW()
} }
inline void SetLinger(int timeout = 0) { inline void SetLinger(int timeout = 0) {
#if !IS_MINGW()
struct linger sl; struct linger sl;
sl.l_onoff = 1; /* non-zero value enables linger option in kernel */ sl.l_onoff = 1; /* non-zero value enables linger option in kernel */
sl.l_linger = timeout; /* timeout interval in seconds */ sl.l_linger = timeout; /* timeout interval in seconds */
if (setsockopt(sockfd, SOL_SOCKET, SO_LINGER, reinterpret_cast<char*>(&sl), sizeof(sl)) == -1) { if (setsockopt(sockfd, SOL_SOCKET, SO_LINGER, reinterpret_cast<char*>(&sl), sizeof(sl)) == -1) {
Socket::Error("SO_LINGER"); Socket::Error("SO_LINGER");
} }
#endif // !IS_MINGW()
} }
/*! /*!
* \brief create the socket, call this before using socket * \brief create the socket, call this before using socket
* \param af domain * \param af domain
*/ */
inline void Create(int af = PF_INET) { inline void Create(int af = PF_INET) {
#if !IS_MINGW()
sockfd = socket(PF_INET, SOCK_STREAM, 0); sockfd = socket(PF_INET, SOCK_STREAM, 0);
if (sockfd == kInvalidSocket) { if (sockfd == kInvalidSocket) {
Socket::Error("Create"); Socket::Error("Create");
} }
#endif // !IS_MINGW()
} }
/*! /*!
* \brief perform listen of the socket * \brief perform listen of the socket
* \param backlog backlog parameter * \param backlog backlog parameter
*/ */
inline void Listen(int backlog = 16) { inline void Listen(int backlog = 16) {
#if !IS_MINGW()
listen(sockfd, backlog); listen(sockfd, backlog);
#endif // !IS_MINGW()
} }
/*! \brief get a new connection */ /*! \brief get a new connection */
TCPSocket Accept() { TCPSocket Accept() {
#if !IS_MINGW()
SOCKET newfd = accept(sockfd, nullptr, nullptr); SOCKET newfd = accept(sockfd, nullptr, nullptr);
if (newfd == kInvalidSocket) { if (newfd == kInvalidSocket) {
Socket::Error("Accept"); Socket::Error("Accept");
} }
return TCPSocket(newfd); return TCPSocket(newfd);
#else
return TCPSocket();
#endif // !IS_MINGW()
} }
/*! /*!
* \brief decide whether the socket is at OOB mark * \brief decide whether the socket is at OOB mark
* \return 1 if at mark, 0 if not, -1 if an error occured * \return 1 if at mark, 0 if not, -1 if an error occured
*/ */
inline int AtMark() const { inline int AtMark() const {
#if !IS_MINGW()
#ifdef _WIN32 #ifdef _WIN32
unsigned long atmark; // NOLINT(*) unsigned long atmark; // NOLINT(*)
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1; if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
@ -328,7 +433,12 @@ class TCPSocket : public Socket{
int atmark; int atmark;
if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1; if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
#endif // _WIN32 #endif // _WIN32
return static_cast<int>(atmark); return static_cast<int>(atmark);
#else
return -1;
#endif // !IS_MINGW()
} }
/*! /*!
* \brief connect to an address * \brief connect to an address
@ -336,8 +446,12 @@ class TCPSocket : public Socket{
* \return whether connect is successful * \return whether connect is successful
*/ */
inline bool Connect(const SockAddr &addr) { inline bool Connect(const SockAddr &addr) {
#if !IS_MINGW()
return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr), return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
sizeof(addr.addr)) == 0; sizeof(addr.addr)) == 0;
#else
return false;
#endif // !IS_MINGW()
} }
/*! /*!
* \brief send data using the socket * \brief send data using the socket
@ -349,7 +463,11 @@ class TCPSocket : public Socket{
*/ */
inline ssize_t Send(const void *buf_, size_t len, int flag = 0) { inline ssize_t Send(const void *buf_, size_t len, int flag = 0) {
const char *buf = reinterpret_cast<const char*>(buf_); const char *buf = reinterpret_cast<const char*>(buf_);
#if !IS_MINGW()
return send(sockfd, buf, static_cast<sock_size_t>(len), flag); return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
#else
return 0;
#endif // !IS_MINGW()
} }
/*! /*!
* \brief receive data using the socket * \brief receive data using the socket
@ -361,7 +479,11 @@ class TCPSocket : public Socket{
*/ */
inline ssize_t Recv(void *buf_, size_t len, int flags = 0) { inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
char *buf = reinterpret_cast<char*>(buf_); char *buf = reinterpret_cast<char*>(buf_);
#if !IS_MINGW()
return recv(sockfd, buf, static_cast<sock_size_t>(len), flags); return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
#else
return 0;
#endif // !IS_MINGW()
} }
/*! /*!
* \brief peform block write that will attempt to send all data out * \brief peform block write that will attempt to send all data out
@ -373,6 +495,7 @@ class TCPSocket : public Socket{
inline size_t SendAll(const void *buf_, size_t len) { inline size_t SendAll(const void *buf_, size_t len) {
const char *buf = reinterpret_cast<const char*>(buf_); const char *buf = reinterpret_cast<const char*>(buf_);
size_t ndone = 0; size_t ndone = 0;
#if !IS_MINGW()
while (ndone < len) { while (ndone < len) {
ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0); ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
if (ret == -1) { if (ret == -1) {
@ -382,6 +505,7 @@ class TCPSocket : public Socket{
buf += ret; buf += ret;
ndone += ret; ndone += ret;
} }
#endif // !IS_MINGW()
return ndone; return ndone;
} }
/*! /*!
@ -394,6 +518,7 @@ class TCPSocket : public Socket{
inline size_t RecvAll(void *buf_, size_t len) { inline size_t RecvAll(void *buf_, size_t len) {
char *buf = reinterpret_cast<char*>(buf_); char *buf = reinterpret_cast<char*>(buf_);
size_t ndone = 0; size_t ndone = 0;
#if !IS_MINGW()
while (ndone < len) { while (ndone < len) {
ssize_t ret = recv(sockfd, buf, ssize_t ret = recv(sockfd, buf,
static_cast<sock_size_t>(len - ndone), MSG_WAITALL); static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
@ -405,6 +530,7 @@ class TCPSocket : public Socket{
buf += ret; buf += ret;
ndone += ret; ndone += ret;
} }
#endif // !IS_MINGW()
return ndone; return ndone;
} }
/*! /*!
@ -500,7 +626,7 @@ struct PollHelper {
pollfd pfd; pollfd pfd;
pfd.fd = fd; pfd.fd = fd;
pfd.events = POLLPRI; pfd.events = POLLPRI;
return poll(&pfd, 1, timeout); return PollImpl(&pfd, 1, timeout);
} }
/*! /*!
@ -514,7 +640,7 @@ struct PollHelper {
for (auto kv : fds) { for (auto kv : fds) {
fdset.push_back(kv.second); fdset.push_back(kv.second);
} }
int ret = poll(fdset.data(), fdset.size(), timeout); int ret = PollImpl(fdset.data(), fdset.size(), timeout);
if (ret == -1) { if (ret == -1) {
Socket::Error("Poll"); Socket::Error("Poll");
} else { } else {
@ -533,4 +659,11 @@ struct PollHelper {
}; };
} // namespace utils } // namespace utils
} // namespace rabit } // namespace rabit
#if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
#undef POLLIN
#undef POLLPRI
#undef POLLOUT
#endif // IS_MINGW()
#endif // RABIT_INTERNAL_SOCKET_H_ #endif // RABIT_INTERNAL_SOCKET_H_

View File

@ -15,10 +15,7 @@
#include <stdexcept> #include <stdexcept>
#include <vector> #include <vector>
#include "dmlc/io.h" #include "dmlc/io.h"
#ifndef RABIT_STRICT_CXX98_
#include <cstdarg> #include <cstdarg>
#endif // RABIT_STRICT_CXX98_
#if !defined(__GNUC__) || defined(__FreeBSD__) #if !defined(__GNUC__) || defined(__FreeBSD__)
#define fopen64 std::fopen #define fopen64 std::fopen
@ -71,7 +68,6 @@ inline bool StringToBool(const char* s) {
return CompareStringsCaseInsensitive(s, "true") == 0 || atoi(s) != 0; return CompareStringsCaseInsensitive(s, "true") == 0 || atoi(s) != 0;
} }
#ifndef RABIT_CUSTOMIZE_MSG_
/*! /*!
* \brief handling of Assert error, caused by inappropriate input * \brief handling of Assert error, caused by inappropriate input
* \param msg error message * \param msg error message
@ -89,6 +85,7 @@ inline void HandleCheckError(const char *msg) {
fprintf(stderr, "%s, rabit is configured to keep process running\n", msg); fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
throw dmlc::Error(msg); throw dmlc::Error(msg);
} }
inline void HandlePrint(const char *msg) { inline void HandlePrint(const char *msg) {
printf("%s", msg); printf("%s", msg);
} }
@ -102,22 +99,7 @@ inline void HandleLogInfo(const char *fmt, ...) {
fprintf(stdout, "%s", msg.c_str()); fprintf(stdout, "%s", msg.c_str());
fflush(stdout); fflush(stdout);
} }
#else
#ifndef RABIT_STRICT_CXX98_
// include declarations, some one must implement this
void HandleAssertError(const char *msg);
void HandleCheckError(const char *msg);
void HandlePrint(const char *msg);
#endif // RABIT_STRICT_CXX98_
#endif // RABIT_CUSTOMIZE_MSG_
#ifdef RABIT_STRICT_CXX98_
// these function pointers are to be assigned
extern "C" void (*Printf)(const char *fmt, ...);
extern "C" int (*SPrintf)(char *buf, size_t size, const char *fmt, ...);
extern "C" void (*Assert)(int exp, const char *fmt, ...);
extern "C" void (*Check)(int exp, const char *fmt, ...);
extern "C" void (*Error)(const char *fmt, ...);
#else
/*! \brief printf, prints messages to the console */ /*! \brief printf, prints messages to the console */
inline void Printf(const char *fmt, ...) { inline void Printf(const char *fmt, ...) {
std::string msg(kPrintBuffer, '\0'); std::string msg(kPrintBuffer, '\0');
@ -127,6 +109,7 @@ inline void Printf(const char *fmt, ...) {
va_end(args); va_end(args);
HandlePrint(msg.c_str()); HandlePrint(msg.c_str());
} }
/*! \brief portable version of snprintf */ /*! \brief portable version of snprintf */
inline int SPrintf(char *buf, size_t size, const char *fmt, ...) { inline int SPrintf(char *buf, size_t size, const char *fmt, ...) {
va_list args; va_list args;
@ -171,7 +154,6 @@ inline void Error(const char *fmt, ...) {
HandleCheckError(msg.c_str()); HandleCheckError(msg.c_str());
} }
} }
#endif // RABIT_STRICT_CXX98_
/*! \brief replace fopen, report error when the file open fails */ /*! \brief replace fopen, report error when the file open fails */
inline std::FILE *FopenCheck(const char *fname, const char *flag) { inline std::FILE *FopenCheck(const char *fname, const char *flag) {
@ -180,6 +162,19 @@ inline std::FILE *FopenCheck(const char *fname, const char *flag) {
return fp; return fp;
} }
} // namespace utils } // namespace utils
// Can not use std::min on Windows with msvc due to:
// error C2589: '(': illegal token on right side of '::'
template <typename T>
auto Min(T const& l, T const& r) {
return l < r ? l : r;
}
// same with Min
template <typename T>
auto Max(T const& l, T const& r) {
return l > r ? l : r;
}
// easy utils that can be directly accessed in xgboost // easy utils that can be directly accessed in xgboost
/*! \brief get the beginning address of a vector */ /*! \brief get the beginning address of a vector */
template<typename T> template<typename T>

View File

@ -8,7 +8,11 @@
#define NOMINMAX #define NOMINMAX
#include "allreduce_base.h" #include "allreduce_base.h"
#include <rabit/base.h> #include <rabit/base.h>
#ifndef _WIN32
#include <netinet/tcp.h> #include <netinet/tcp.h>
#endif // _WIN32
#include <cstring> #include <cstring>
#include <map> #include <map>
@ -413,8 +417,12 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
all_link.sock.SetNonBlock(true); all_link.sock.SetNonBlock(true);
all_link.sock.SetKeepAlive(true); all_link.sock.SetKeepAlive(true);
if (rabit_enable_tcp_no_delay) { if (rabit_enable_tcp_no_delay) {
#if defined(__unix__)
setsockopt(all_link.sock, IPPROTO_TCP, setsockopt(all_link.sock, IPPROTO_TCP,
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay)); TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
#else
fprintf(stderr, "tcp no delay is not implemented on non unix platforms\n");
#endif
} }
if (tree_neighbors.count(all_link.rank) != 0) { if (tree_neighbors.count(all_link.rank) != 0) {
if (all_link.rank == parent_rank) { if (all_link.rank == parent_rank) {

View File

@ -306,10 +306,11 @@ class AllreduceBase : public IEngine {
// constructor // constructor
LinkRecord() = default; LinkRecord() = default;
// initialize buffer // initialize buffer
inline void InitBuffer(size_t type_nbytes, size_t count, void InitBuffer(size_t type_nbytes, size_t count,
size_t reduce_buffer_size) { size_t reduce_buffer_size) {
size_t n = (type_nbytes * count + 7)/ 8; size_t n = (type_nbytes * count + 7)/ 8;
buffer_.resize(std::min(reduce_buffer_size, n)); auto to = Min(reduce_buffer_size, n);
buffer_.resize(to);
// make sure align to type_nbytes // make sure align to type_nbytes
buffer_size = buffer_size =
buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes; buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
@ -338,8 +339,8 @@ class AllreduceBase : public IEngine {
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check"); utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
size_t offset = size_read % buffer_size; size_t offset = size_read % buffer_size;
size_t nmax = max_size_read - size_read; size_t nmax = max_size_read - size_read;
nmax = std::min(nmax, buffer_size - ngap); nmax = Min(nmax, buffer_size - ngap);
nmax = std::min(nmax, buffer_size - offset); nmax = Min(nmax, buffer_size - offset);
if (nmax == 0) return kSuccess; if (nmax == 0) return kSuccess;
ssize_t len = sock.Recv(buffer_head + offset, nmax); ssize_t len = sock.Recv(buffer_head + offset, nmax);
// length equals 0, remote disconnected // length equals 0, remote disconnected

View File

@ -217,11 +217,11 @@ class AllreduceRobust : public AllreduceBase {
*/ */
struct ActionSummary { struct ActionSummary {
// maximumly allowed sequence id // maximumly allowed sequence id
static const u_int32_t kSpecialOp = (1 << 26); static const uint32_t kSpecialOp = (1 << 26);
// special sequence number for local state checkpoint // special sequence number for local state checkpoint
static const u_int32_t kLocalCheckPoint = (1 << 26) - 2; static const uint32_t kLocalCheckPoint = (1 << 26) - 2;
// special sequnce number for local state checkpoint ack signal // special sequnce number for local state checkpoint ack signal
static const u_int32_t kLocalCheckAck = (1 << 26) - 1; static const uint32_t kLocalCheckAck = (1 << 26) - 1;
//--------------------------------------------- //---------------------------------------------
// The following are bit mask of flag used in // The following are bit mask of flag used in
//---------------------------------------------- //----------------------------------------------
@ -242,13 +242,13 @@ class AllreduceRobust : public AllreduceBase {
ActionSummary() = default; ActionSummary() = default;
// constructor of action // constructor of action
explicit ActionSummary(int seqno_flag, int cache_flag = 0, explicit ActionSummary(int seqno_flag, int cache_flag = 0,
u_int32_t minseqno = kSpecialOp, u_int32_t maxseqno = kSpecialOp) { uint32_t minseqno = kSpecialOp, uint32_t maxseqno = kSpecialOp) {
seqcode_ = (minseqno << 5) | seqno_flag; seqcode_ = (minseqno << 5) | seqno_flag;
maxseqcode_ = (maxseqno << 5) | cache_flag; maxseqcode_ = (maxseqno << 5) | cache_flag;
} }
// minimum number of all operations by default // minimum number of all operations by default
// maximum number of all cache operations otherwise // maximum number of all cache operations otherwise
inline u_int32_t Seqno(SeqType t = SeqType::kSeq) const { inline uint32_t Seqno(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_; int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return code >> 5; return code >> 5;
} }
@ -294,8 +294,8 @@ class AllreduceRobust : public AllreduceBase {
const ActionSummary *src = static_cast<const ActionSummary*>(src_); const ActionSummary *src = static_cast<const ActionSummary*>(src_);
ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_); ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
u_int32_t min_seqno = std::min(src[i].Seqno(), dst[i].Seqno()); uint32_t min_seqno = Min(src[i].Seqno(), dst[i].Seqno());
u_int32_t max_seqno = std::max(src[i].Seqno(SeqType::kCache), uint32_t max_seqno = Max(src[i].Seqno(SeqType::kCache),
dst[i].Seqno(SeqType::kCache)); dst[i].Seqno(SeqType::kCache));
int action_flag = src[i].Flag() | dst[i].Flag(); int action_flag = src[i].Flag() | dst[i].Flag();
// if any node is not requester set to 0 otherwise 1 // if any node is not requester set to 0 otherwise 1
@ -310,9 +310,9 @@ class AllreduceRobust : public AllreduceBase {
private: private:
// internel sequence code min of rabit seqno // internel sequence code min of rabit seqno
u_int32_t seqcode_; uint32_t seqcode_;
// internal sequence code max of cache seqno // internal sequence code max of cache seqno
u_int32_t maxseqcode_; uint32_t maxseqcode_;
}; };
/*! \brief data structure to remember result of Bcast and Allreduce calls*/ /*! \brief data structure to remember result of Bcast and Allreduce calls*/
class ResultBuffer{ class ResultBuffer{

View File

@ -11,4 +11,3 @@
// switch engine to AllreduceMock // switch engine to AllreduceMock
#define RABIT_USE_BASE #define RABIT_USE_BASE
#include "engine.cc" #include "engine.cc"

View File

@ -1,132 +0,0 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_empty.cc
* \brief this file provides a dummy implementation of engine that does nothing
* this file provides a way to fall back to single node program without causing too many dependencies
* This is usually NOT needed, use engine_mpi or engine for real distributed version
* \author Tianqi Chen
*/
#define NOMINMAX
#include <rabit/base.h>
#include "rabit/internal/engine.h"
namespace rabit {
namespace engine {
/*! \brief EmptyEngine */
class EmptyEngine : public IEngine {
public:
EmptyEngine() {
version_number_ = 0;
}
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
size_t slice_end, size_t size_prev_slice, const char *_file,
const int _line, const char *_caller) override {
utils::Error("EmptyEngine:: Allgather is not supported");
}
int GetRingPrevRank() const override {
utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
return -1;
}
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
ReduceFunction reducer, PreprocFunction prepare_fun,
void *prepare_arg, const char *_file, const int _line,
const char *_caller) override {
utils::Error("EmptyEngine:: Allreduce is not supported,"\
"use Allreduce_ instead");
}
void Broadcast(void *sendrecvbuf_, size_t size, int root,
const char* _file, const int _line, const char* _caller) override {
}
void InitAfterException() override {
utils::Error("EmptyEngine is not fault tolerant");
}
int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = nullptr) override {
return 0;
}
void CheckPoint(const Serializable *global_model,
const Serializable *local_model = nullptr) override {
version_number_ += 1;
}
void LazyCheckPoint(const Serializable *global_model) override {
version_number_ += 1;
}
int VersionNumber() const override {
return version_number_;
}
/*! \brief get rank of current node */
int GetRank() const override {
return 0;
}
/*! \brief get total number of */
int GetWorldSize() const override {
return 1;
}
/*! \brief whether it is distributed */
bool IsDistributed() const override {
return false;
}
/*! \brief get the host name of current node */
std::string GetHost() const override {
return std::string("");
}
void TrackerPrint(const std::string &msg) override {
// simply print information into the tracker
utils::Printf("%s", msg.c_str());
}
private:
int version_number_;
};
// singleton sync manager
EmptyEngine manager;
/*! \brief intiialize the synchronization module */
bool Init(int argc, char *argv[]) {
return true;
}
/*! \brief finalize syncrhonization module */
bool Finalize() {
return true;
}
/*! \brief singleton method to get engine */
IEngine *GetEngine() {
return &manager;
}
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
}
// code for reduce handle
ReduceHandle::ReduceHandle() = default;
ReduceHandle::~ReduceHandle() = default;
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return 0;
}
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {}
void ReduceHandle::Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
}
} // namespace engine
} // namespace rabit

View File

@ -26,7 +26,6 @@ def get_mingw_bin():
def test_with_autotools(args): def test_with_autotools(args):
with DirectoryExcursion(r_package): with DirectoryExcursion(r_package):
if args.compiler == 'mingw':
mingw_bin = get_mingw_bin() mingw_bin = get_mingw_bin()
CXX = os.path.join(mingw_bin, 'g++.exe') CXX = os.path.join(mingw_bin, 'g++.exe')
CC = os.path.join(mingw_bin, 'gcc.exe') CC = os.path.join(mingw_bin, 'gcc.exe')
@ -34,14 +33,6 @@ def test_with_autotools(args):
env = os.environ.copy() env = os.environ.copy()
env.update({'CC': CC, 'CXX': CXX}) env.update({'CC': CC, 'CXX': CXX})
subprocess.check_call(cmd, env=env) subprocess.check_call(cmd, env=env)
elif args.compiler == 'msvc':
cmd = ['R.exe', 'CMD', 'INSTALL', str(os.path.curdir)]
env = os.environ.copy()
# autotool favor mingw by default.
env.update({'CC': 'cl.exe', 'CXX': 'cl.exe'})
subprocess.check_call(cmd, env=env)
else:
raise ValueError('Wrong compiler')
subprocess.check_call([ subprocess.check_call([
'R.exe', '-q', '-e', 'R.exe', '-q', '-e',
"library(testthat); setwd('tests'); source('testthat.R')" "library(testthat); setwd('tests'); source('testthat.R')"