Enable building rabit on Windows (#6105)
This commit is contained in:
parent
08bdb2efc8
commit
c92d751ad1
3
.github/workflows/main.yml
vendored
3
.github/workflows/main.yml
vendored
@ -203,9 +203,8 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- {os: windows-2016, r: 'release', compiler: 'msvc', build: 'autotools'}
|
||||
- {os: windows-2016, r: 'release', compiler: 'msvc', build: 'cmake'}
|
||||
- {os: windows-2016, r: 'release', compiler: 'mingw', build: 'autotools'}
|
||||
- {os: windows-2016, r: 'release', compiler: 'msvc', build: 'cmake'}
|
||||
- {os: windows-2016, r: 'release', compiler: 'mingw', build: 'cmake'}
|
||||
env:
|
||||
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true
|
||||
|
||||
@ -177,7 +177,7 @@ else()
|
||||
-D_CRT_SECURE_NO_WARNINGS -D_CRT_SECURE_NO_DEPRECATE)
|
||||
endif (MSVC)
|
||||
endif(RABIT_MOCK)
|
||||
foreach(lib rabit rabit_base rabit_empty rabit_mock rabit_mock_static)
|
||||
foreach(lib rabit rabit_base rabit_mock rabit_mock_static)
|
||||
# Explicitly link dmlc to rabit, so that configured header (build_config.h)
|
||||
# from dmlc is correctly applied to rabit.
|
||||
if (TARGET ${lib})
|
||||
|
||||
@ -52,4 +52,3 @@ AC_SUBST(ENDIAN_FLAG)
|
||||
AC_SUBST(BACKTRACE_LIB)
|
||||
AC_CONFIG_FILES([src/Makevars])
|
||||
AC_OUTPUT
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ CXX_STD = CXX14
|
||||
XGB_RFLAGS = -DXGBOOST_STRICT_R_MODE=1 -DDMLC_LOG_BEFORE_THROW=0\
|
||||
-DDMLC_ENABLE_STD_THREAD=$(ENABLE_STD_THREAD) -DDMLC_DISABLE_STDIN=1\
|
||||
-DDMLC_LOG_CUSTOMIZE=1 -DXGBOOST_CUSTOMIZE_LOGGER=1\
|
||||
-DRABIT_CUSTOMIZE_MSG_ -DRABIT_STRICT_CXX98_
|
||||
-DRABIT_CUSTOMIZE_MSG_
|
||||
|
||||
# disable the use of thread_local for 32 bit windows:
|
||||
ifeq ($(R_OSTYPE)$(WIN),windows)
|
||||
@ -19,6 +19,7 @@ $(foreach v, $(XGB_RFLAGS), $(warning $(v)))
|
||||
PKG_CPPFLAGS= -I$(PKGROOT)/include -I$(PKGROOT)/dmlc-core/include -I$(PKGROOT)/rabit/include -I$(PKGROOT) $(XGB_RFLAGS)
|
||||
PKG_CXXFLAGS= @OPENMP_CXXFLAGS@ @ENDIAN_FLAG@ -pthread
|
||||
PKG_LIBS = @OPENMP_CXXFLAGS@ @OPENMP_LIB@ @ENDIAN_FLAG@ @BACKTRACE_LIB@ -pthread
|
||||
OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o ./init.o\
|
||||
$(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o\
|
||||
$(PKGROOT)/rabit/src/engine_empty.o $(PKGROOT)/rabit/src/c_api.o
|
||||
OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o ./init.o \
|
||||
$(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o \
|
||||
$(PKGROOT)/rabit/src/engine.o $(PKGROOT)/rabit/src/c_api.o \
|
||||
$(PKGROOT)/rabit/src/allreduce_base.o $(PKGROOT)/rabit/src/allreduce_robust.o
|
||||
|
||||
@ -20,7 +20,7 @@ CXX_STD = CXX14
|
||||
XGB_RFLAGS = -DXGBOOST_STRICT_R_MODE=1 -DDMLC_LOG_BEFORE_THROW=0\
|
||||
-DDMLC_ENABLE_STD_THREAD=$(ENABLE_STD_THREAD) -DDMLC_DISABLE_STDIN=1\
|
||||
-DDMLC_LOG_CUSTOMIZE=1 -DXGBOOST_CUSTOMIZE_LOGGER=1\
|
||||
-DRABIT_CUSTOMIZE_MSG_ -DRABIT_STRICT_CXX98_
|
||||
-DRABIT_CUSTOMIZE_MSG_
|
||||
|
||||
# disable the use of thread_local for 32 bit windows:
|
||||
ifeq ($(R_OSTYPE)$(WIN),windows)
|
||||
@ -31,8 +31,9 @@ $(foreach v, $(XGB_RFLAGS), $(warning $(v)))
|
||||
PKG_CPPFLAGS= -I$(PKGROOT)/include -I$(PKGROOT)/dmlc-core/include -I$(PKGROOT)/rabit/include -I$(PKGROOT) $(XGB_RFLAGS)
|
||||
PKG_CXXFLAGS= $(SHLIB_OPENMP_CXXFLAGS) $(SHLIB_PTHREAD_FLAGS)
|
||||
PKG_LIBS = $(SHLIB_OPENMP_CXXFLAGS) $(SHLIB_PTHREAD_FLAGS)
|
||||
OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o ./init.o\
|
||||
$(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o\
|
||||
$(PKGROOT)/rabit/src/engine_empty.o $(PKGROOT)/rabit/src/c_api.o
|
||||
OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o ./init.o \
|
||||
$(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o \
|
||||
$(PKGROOT)/rabit/src/engine.o $(PKGROOT)/rabit/src/c_api.o \
|
||||
$(PKGROOT)/rabit/src/allreduce_base.o $(PKGROOT)/rabit/src/allreduce_robust.o
|
||||
|
||||
$(OBJECTS) : xgblib
|
||||
|
||||
@ -13,23 +13,6 @@ void CustomLogMessage::Log(const std::string& msg) {
|
||||
}
|
||||
} // namespace dmlc
|
||||
|
||||
// implements rabit error handling.
|
||||
extern "C" {
|
||||
void XGBoostAssert_R(int exp, const char *fmt, ...);
|
||||
void XGBoostCheck_R(int exp, const char *fmt, ...);
|
||||
}
|
||||
|
||||
namespace rabit {
|
||||
namespace utils {
|
||||
extern "C" {
|
||||
void (*Printf)(const char *fmt, ...) = Rprintf;
|
||||
void (*Assert)(int exp, const char *fmt, ...) = XGBoostAssert_R;
|
||||
void (*Check)(int exp, const char *fmt, ...) = XGBoostCheck_R;
|
||||
void (*Error)(const char *fmt, ...) = error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace xgboost {
|
||||
ConsoleLogger::~ConsoleLogger() {
|
||||
if (cur_verbosity_ == LogVerbosity::kIgnore ||
|
||||
|
||||
@ -12,5 +12,3 @@
|
||||
#include "../dmlc-core/src/data.cc"
|
||||
#include "../dmlc-core/src/io.cc"
|
||||
#include "../dmlc-core/src/recordio.cc"
|
||||
|
||||
|
||||
|
||||
@ -1,28 +1,19 @@
|
||||
cmake_minimum_required(VERSION 3.3)
|
||||
|
||||
if(R_LIB OR MINGW OR WIN32)
|
||||
add_library(rabit src/engine_empty.cc src/c_api.cc)
|
||||
set(rabit_libs rabit)
|
||||
set_target_properties(rabit
|
||||
PROPERTIES CXX_STANDARD 14
|
||||
CXX_STANDARD_REQUIRED ON
|
||||
POSITION_INDEPENDENT_CODE ON)
|
||||
else()
|
||||
find_package(Threads REQUIRED)
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
add_library(rabit src/allreduce_base.cc src/allreduce_robust.cc src/engine.cc src/c_api.cc)
|
||||
add_library(rabit_mock_static src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc)
|
||||
add_library(rabit_mock SHARED src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc)
|
||||
target_link_libraries(rabit Threads::Threads dmlc)
|
||||
target_link_libraries(rabit_mock_static Threads::Threads dmlc)
|
||||
target_link_libraries(rabit_mock Threads::Threads dmlc)
|
||||
add_library(rabit src/allreduce_base.cc src/allreduce_robust.cc src/engine.cc src/c_api.cc)
|
||||
add_library(rabit_mock_static src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc)
|
||||
add_library(rabit_mock SHARED src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc)
|
||||
target_link_libraries(rabit Threads::Threads dmlc)
|
||||
target_link_libraries(rabit_mock_static Threads::Threads dmlc)
|
||||
target_link_libraries(rabit_mock Threads::Threads dmlc)
|
||||
|
||||
set(rabit_libs rabit rabit_mock rabit_mock_static)
|
||||
set_target_properties(rabit rabit_mock rabit_mock_static
|
||||
PROPERTIES CXX_STANDARD 14
|
||||
CXX_STANDARD_REQUIRED ON
|
||||
POSITION_INDEPENDENT_CODE ON)
|
||||
endif(R_LIB OR MINGW OR WIN32)
|
||||
set(rabit_libs rabit rabit_mock rabit_mock_static)
|
||||
set_target_properties(rabit rabit_mock rabit_mock_static
|
||||
PROPERTIES CXX_STANDARD 14
|
||||
CXX_STANDARD_REQUIRED ON
|
||||
POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
if(RABIT_BUILD_MPI)
|
||||
find_package(MPI REQUIRED)
|
||||
|
||||
104
rabit/Makefile
104
rabit/Makefile
@ -1,104 +0,0 @@
|
||||
OS := $(shell uname)
|
||||
|
||||
RABIT_BUILD_DMLC = 0
|
||||
|
||||
export WARNFLAGS= -Wall -Wextra -Wno-unused-parameter -Wno-unknown-pragmas -std=c++11
|
||||
export CFLAGS = -O3 $(WARNFLAGS)
|
||||
export LDFLAGS =-Llib
|
||||
|
||||
#download mpi
|
||||
#echo $(shell scripts/mpi.sh)
|
||||
|
||||
MPICXX=./mpich/bin/mpicxx
|
||||
|
||||
export CXX = g++
|
||||
|
||||
|
||||
#----------------------------
|
||||
# Settings for power and arm arch
|
||||
#----------------------------
|
||||
ARCH := $(shell uname -a)
|
||||
ifneq (,$(filter $(ARCH), armv6l armv7l powerpc64le ppc64le aarch64))
|
||||
CFLAGS += -march=native
|
||||
else
|
||||
CFLAGS += -msse2
|
||||
endif
|
||||
|
||||
ifndef WITH_FPIC
|
||||
WITH_FPIC = 1
|
||||
endif
|
||||
ifeq ($(WITH_FPIC), 1)
|
||||
CFLAGS += -fPIC
|
||||
endif
|
||||
|
||||
ifndef LINT_LANG
|
||||
LINT_LANG="all"
|
||||
endif
|
||||
|
||||
ifeq ($(RABIT_BUILD_DMLC),1)
|
||||
DMLC=dmlc-core
|
||||
else
|
||||
DMLC=../dmlc-core
|
||||
endif
|
||||
|
||||
CFLAGS += -I $(DMLC)/include -I include/
|
||||
|
||||
# build path
|
||||
BPATH=.
|
||||
# objectives that makes up rabit library
|
||||
MPIOBJ= $(BPATH)/engine_mpi.o
|
||||
OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/engine_empty.o $(BPATH)/engine_mock.o\
|
||||
$(BPATH)/c_api.o $(BPATH)/engine_base.o
|
||||
SLIB= lib/librabit.so lib/librabit_mock.so lib/librabit_base.so
|
||||
ALIB= lib/librabit.a lib/librabit_empty.a lib/librabit_mock.a lib/librabit_base.a
|
||||
MPISLIB= lib/librabit_mpi.so
|
||||
MPIALIB= lib/librabit_mpi.a
|
||||
HEADERS=src/*.h include/rabit/*.h include/rabit/internal/*.h
|
||||
|
||||
.PHONY: clean all install mpi python lint doc doxygen
|
||||
|
||||
all: lib/librabit.a lib/librabit_mock.a lib/librabit.so lib/librabit_base.a lib/librabit_mock.so
|
||||
mpi: lib/librabit_mpi.a lib/librabit_mpi.so
|
||||
|
||||
$(BPATH)/allreduce_base.o: src/allreduce_base.cc $(HEADERS)
|
||||
$(BPATH)/engine.o: src/engine.cc $(HEADERS)
|
||||
$(BPATH)/allreduce_robust.o: src/allreduce_robust.cc $(HEADERS)
|
||||
$(BPATH)/engine_mpi.o: src/engine_mpi.cc $(HEADERS)
|
||||
$(BPATH)/engine_empty.o: src/engine_empty.cc $(HEADERS)
|
||||
$(BPATH)/engine_mock.o: src/engine_mock.cc $(HEADERS)
|
||||
$(BPATH)/engine_base.o: src/engine_base.cc $(HEADERS)
|
||||
$(BPATH)/c_api.o: src/c_api.cc $(HEADERS)
|
||||
|
||||
lib/librabit.a lib/librabit.so: $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/c_api.o
|
||||
lib/librabit_base.a lib/librabit_base.so: $(BPATH)/allreduce_base.o $(BPATH)/engine_base.o $(BPATH)/c_api.o
|
||||
lib/librabit_mock.a lib/librabit_mock.so: $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine_mock.o $(BPATH)/c_api.o
|
||||
lib/librabit_empty.a: $(BPATH)/engine_empty.o $(BPATH)/c_api.o
|
||||
lib/librabit_mpi.a lib/librabit_mpi.so: $(MPIOBJ)
|
||||
|
||||
$(OBJ) :
|
||||
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
||||
|
||||
$(ALIB):
|
||||
ar cr $@ $+
|
||||
|
||||
$(SLIB) :
|
||||
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS)
|
||||
|
||||
$(MPIOBJ) :
|
||||
$(MPICXX) -c $(CFLAGS) -I./mpich/include -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
||||
|
||||
$(MPIALIB):
|
||||
ar cr $@ $+
|
||||
|
||||
$(MPISLIB) :
|
||||
$(MPICXX) $(CFLAGS) -I./mpich/include -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) \
|
||||
$(LDFLAGS) -L./mpich/lib -Wl,-rpath,./mpich/lib -lmpi
|
||||
|
||||
lint:
|
||||
$(DMLC)/scripts/lint.py rabit $(LINT_LANG) src include
|
||||
|
||||
doc doxygen:
|
||||
cd include; doxygen ../doc/Doxyfile; cd -
|
||||
|
||||
clean:
|
||||
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) $(SLIB) *~ src/*~ include/*~ include/*/*~
|
||||
@ -9,10 +9,13 @@
|
||||
#if defined(_WIN32)
|
||||
#include <winsock2.h>
|
||||
#include <ws2tcpip.h>
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#pragma comment(lib, "Ws2_32.lib")
|
||||
#endif // _MSC_VER
|
||||
|
||||
#else
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <cerrno>
|
||||
@ -21,31 +24,92 @@
|
||||
#include <netinet/in.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/ioctl.h>
|
||||
|
||||
#endif // defined(_WIN32)
|
||||
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include "utils.h"
|
||||
|
||||
#if defined(_WIN32) || defined(__MINGW32__)
|
||||
#if defined(_WIN32) && !defined(__MINGW32__)
|
||||
typedef int ssize_t;
|
||||
#endif // defined(_WIN32) || defined(__MINGW32__)
|
||||
|
||||
#if defined(_WIN32)
|
||||
typedef int sock_size_t;
|
||||
using sock_size_t = int;
|
||||
|
||||
static inline int poll(struct pollfd *pfd, int nfds,
|
||||
int timeout) { return WSAPoll ( pfd, nfds, timeout ); }
|
||||
#else
|
||||
|
||||
#include <sys/poll.h>
|
||||
using SOCKET = int;
|
||||
using sock_size_t = size_t; // NOLINT
|
||||
const int kInvalidSocket = -1;
|
||||
#endif // defined(_WIN32)
|
||||
|
||||
#define IS_MINGW() defined(__MINGW32__)
|
||||
|
||||
#if IS_MINGW()
|
||||
inline void MingWError() {
|
||||
throw dmlc::Error("Distributed training on mingw is not supported.");
|
||||
}
|
||||
#endif // IS_MINGW()
|
||||
|
||||
#if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
|
||||
/*
|
||||
* On later mingw versions poll should be supported (with bugs). See:
|
||||
* https://stackoverflow.com/a/60623080
|
||||
*
|
||||
* But right now the mingw distributed with R 3.6 doesn't support it.
|
||||
* So we just give a warning and provide dummy implementation to get
|
||||
* compilation passed. Otherwise we will have to provide a stub for
|
||||
* RABIT.
|
||||
*
|
||||
* Even on mingw version that has these structures and flags defined,
|
||||
* functions like `send` and `listen` might have unresolved linkage to
|
||||
* their implementation. So supporting mingw is quite difficult at
|
||||
* the time of writing.
|
||||
*/
|
||||
#pragma message("Distributed training on mingw is not supported.")
|
||||
typedef struct pollfd {
|
||||
SOCKET fd;
|
||||
short events;
|
||||
short revents;
|
||||
} WSAPOLLFD, *PWSAPOLLFD, *LPWSAPOLLFD;
|
||||
|
||||
// POLLRDNORM | POLLRDBAND
|
||||
#define POLLIN (0x0100 | 0x0200)
|
||||
#define POLLPRI 0x0400
|
||||
// POLLWRNORM
|
||||
#define POLLOUT 0x0010
|
||||
|
||||
inline const char *inet_ntop(int, const void *, char *, size_t) {
|
||||
MingWError();
|
||||
return nullptr;
|
||||
}
|
||||
#endif // IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
|
||||
|
||||
namespace rabit {
|
||||
namespace utils {
|
||||
|
||||
static constexpr int kInvalidSocket = -1;
|
||||
|
||||
template <typename PollFD>
|
||||
int PollImpl(PollFD *pfd, int nfds, int timeout) {
|
||||
#if defined(_WIN32)
|
||||
|
||||
#if IS_MINGW()
|
||||
MingWError();
|
||||
return -1;
|
||||
#else
|
||||
return WSAPoll(pfd, nfds, timeout);
|
||||
#endif // IS_MINGW()
|
||||
|
||||
#else
|
||||
return poll(pfd, nfds, timeout);
|
||||
#endif // IS_MINGW()
|
||||
}
|
||||
|
||||
/*! \brief data structure for network address */
|
||||
struct SockAddr {
|
||||
sockaddr_in addr;
|
||||
@ -56,7 +120,9 @@ struct SockAddr {
|
||||
}
|
||||
inline static std::string GetHostName() {
|
||||
std::string buf; buf.resize(256);
|
||||
#if !IS_MINGW()
|
||||
utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
|
||||
#endif // IS_MINGW()
|
||||
return std::string(buf.c_str());
|
||||
}
|
||||
/*!
|
||||
@ -65,6 +131,7 @@ struct SockAddr {
|
||||
* \param port the port of address
|
||||
*/
|
||||
inline void Set(const char *host, int port) {
|
||||
#if !IS_MINGW()
|
||||
addrinfo hints;
|
||||
memset(&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_INET;
|
||||
@ -76,6 +143,7 @@ struct SockAddr {
|
||||
memcpy(&addr, res->ai_addr, res->ai_addrlen);
|
||||
addr.sin_port = htons(port);
|
||||
freeaddrinfo(res);
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*! \brief return port of the address*/
|
||||
inline int Port() const {
|
||||
@ -112,7 +180,14 @@ class Socket {
|
||||
*/
|
||||
inline static int GetLastError() {
|
||||
#ifdef _WIN32
|
||||
|
||||
#if IS_MINGW()
|
||||
MingWError();
|
||||
return -1;
|
||||
#else
|
||||
return WSAGetLastError();
|
||||
#endif // IS_MINGW()
|
||||
|
||||
#else
|
||||
return errno;
|
||||
#endif // _WIN32
|
||||
@ -132,14 +207,16 @@ class Socket {
|
||||
*/
|
||||
inline static void Startup() {
|
||||
#ifdef _WIN32
|
||||
#if !IS_MINGW()
|
||||
WSADATA wsa_data;
|
||||
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
|
||||
Socket::Error("Startup");
|
||||
Socket::Error("Startup");
|
||||
}
|
||||
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
|
||||
WSACleanup();
|
||||
utils::Error("Could not find a usable version of Winsock.dll\n");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
#endif // _WIN32
|
||||
}
|
||||
/*!
|
||||
@ -147,7 +224,9 @@ class Socket {
|
||||
*/
|
||||
inline static void Finalize() {
|
||||
#ifdef _WIN32
|
||||
#if !IS_MINGW()
|
||||
WSACleanup();
|
||||
#endif // !IS_MINGW()
|
||||
#endif // _WIN32
|
||||
}
|
||||
/*!
|
||||
@ -157,10 +236,12 @@ class Socket {
|
||||
*/
|
||||
inline void SetNonBlock(bool non_block) {
|
||||
#ifdef _WIN32
|
||||
#if !IS_MINGW()
|
||||
u_long mode = non_block ? 1 : 0;
|
||||
if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
|
||||
Socket::Error("SetNonBlock");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
#else
|
||||
int flag = fcntl(sockfd, F_GETFL, 0);
|
||||
if (flag == -1) {
|
||||
@ -181,10 +262,12 @@ class Socket {
|
||||
* \param addr
|
||||
*/
|
||||
inline void Bind(const SockAddr &addr) {
|
||||
#if !IS_MINGW()
|
||||
if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
|
||||
sizeof(addr.addr)) == -1) {
|
||||
Socket::Error("Bind");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief try bind the socket to host, from start_port to end_port
|
||||
@ -194,6 +277,7 @@ class Socket {
|
||||
*/
|
||||
inline int TryBindHost(int start_port, int end_port) {
|
||||
// TODO(tqchen) add prefix check
|
||||
#if !IS_MINGW()
|
||||
for (int port = start_port; port < end_port; ++port) {
|
||||
SockAddr addr("0.0.0.0", port);
|
||||
if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
|
||||
@ -210,17 +294,22 @@ class Socket {
|
||||
}
|
||||
#endif // defined(_WIN32)
|
||||
}
|
||||
|
||||
#endif // !IS_MINGW()
|
||||
return -1;
|
||||
}
|
||||
/*! \brief get last error code if any */
|
||||
inline int GetSockError() const {
|
||||
int error = 0;
|
||||
socklen_t len = sizeof(error);
|
||||
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
|
||||
reinterpret_cast<char*>(&error), &len) != 0) {
|
||||
#if !IS_MINGW()
|
||||
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
|
||||
reinterpret_cast<char *>(&error), &len) != 0) {
|
||||
Error("GetSockError");
|
||||
}
|
||||
#else
|
||||
// undefined reference to `_imp__getsockopt@20'
|
||||
MingWError();
|
||||
#endif // !IS_MINGW()
|
||||
return error;
|
||||
}
|
||||
/*! \brief check if anything bad happens */
|
||||
@ -238,7 +327,9 @@ class Socket {
|
||||
inline void Close() {
|
||||
if (sockfd != kInvalidSocket) {
|
||||
#ifdef _WIN32
|
||||
#if !IS_MINGW()
|
||||
closesocket(sockfd);
|
||||
#endif // !IS_MINGW()
|
||||
#else
|
||||
close(sockfd);
|
||||
#endif
|
||||
@ -277,50 +368,64 @@ class TCPSocket : public Socket{
|
||||
* \param keepalive whether to set the keep alive option on
|
||||
*/
|
||||
void SetKeepAlive(bool keepalive) {
|
||||
#if !IS_MINGW()
|
||||
int opt = static_cast<int>(keepalive);
|
||||
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
|
||||
reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
|
||||
Socket::Error("SetKeepAlive");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
inline void SetLinger(int timeout = 0) {
|
||||
#if !IS_MINGW()
|
||||
struct linger sl;
|
||||
sl.l_onoff = 1; /* non-zero value enables linger option in kernel */
|
||||
sl.l_linger = timeout; /* timeout interval in seconds */
|
||||
if (setsockopt(sockfd, SOL_SOCKET, SO_LINGER, reinterpret_cast<char*>(&sl), sizeof(sl)) == -1) {
|
||||
Socket::Error("SO_LINGER");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief create the socket, call this before using socket
|
||||
* \param af domain
|
||||
*/
|
||||
inline void Create(int af = PF_INET) {
|
||||
#if !IS_MINGW()
|
||||
sockfd = socket(PF_INET, SOCK_STREAM, 0);
|
||||
if (sockfd == kInvalidSocket) {
|
||||
Socket::Error("Create");
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief perform listen of the socket
|
||||
* \param backlog backlog parameter
|
||||
*/
|
||||
inline void Listen(int backlog = 16) {
|
||||
#if !IS_MINGW()
|
||||
listen(sockfd, backlog);
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*! \brief get a new connection */
|
||||
TCPSocket Accept() {
|
||||
#if !IS_MINGW()
|
||||
SOCKET newfd = accept(sockfd, nullptr, nullptr);
|
||||
if (newfd == kInvalidSocket) {
|
||||
Socket::Error("Accept");
|
||||
}
|
||||
return TCPSocket(newfd);
|
||||
#else
|
||||
return TCPSocket();
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief decide whether the socket is at OOB mark
|
||||
* \return 1 if at mark, 0 if not, -1 if an error occured
|
||||
*/
|
||||
inline int AtMark() const {
|
||||
#if !IS_MINGW()
|
||||
|
||||
#ifdef _WIN32
|
||||
unsigned long atmark; // NOLINT(*)
|
||||
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
|
||||
@ -328,7 +433,12 @@ class TCPSocket : public Socket{
|
||||
int atmark;
|
||||
if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
|
||||
#endif // _WIN32
|
||||
|
||||
return static_cast<int>(atmark);
|
||||
|
||||
#else
|
||||
return -1;
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief connect to an address
|
||||
@ -336,8 +446,12 @@ class TCPSocket : public Socket{
|
||||
* \return whether connect is successful
|
||||
*/
|
||||
inline bool Connect(const SockAddr &addr) {
|
||||
#if !IS_MINGW()
|
||||
return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
|
||||
sizeof(addr.addr)) == 0;
|
||||
#else
|
||||
return false;
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief send data using the socket
|
||||
@ -349,7 +463,11 @@ class TCPSocket : public Socket{
|
||||
*/
|
||||
inline ssize_t Send(const void *buf_, size_t len, int flag = 0) {
|
||||
const char *buf = reinterpret_cast<const char*>(buf_);
|
||||
#if !IS_MINGW()
|
||||
return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
|
||||
#else
|
||||
return 0;
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief receive data using the socket
|
||||
@ -361,7 +479,11 @@ class TCPSocket : public Socket{
|
||||
*/
|
||||
inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
|
||||
char *buf = reinterpret_cast<char*>(buf_);
|
||||
#if !IS_MINGW()
|
||||
return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
|
||||
#else
|
||||
return 0;
|
||||
#endif // !IS_MINGW()
|
||||
}
|
||||
/*!
|
||||
* \brief peform block write that will attempt to send all data out
|
||||
@ -373,6 +495,7 @@ class TCPSocket : public Socket{
|
||||
inline size_t SendAll(const void *buf_, size_t len) {
|
||||
const char *buf = reinterpret_cast<const char*>(buf_);
|
||||
size_t ndone = 0;
|
||||
#if !IS_MINGW()
|
||||
while (ndone < len) {
|
||||
ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
|
||||
if (ret == -1) {
|
||||
@ -382,6 +505,7 @@ class TCPSocket : public Socket{
|
||||
buf += ret;
|
||||
ndone += ret;
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
return ndone;
|
||||
}
|
||||
/*!
|
||||
@ -394,6 +518,7 @@ class TCPSocket : public Socket{
|
||||
inline size_t RecvAll(void *buf_, size_t len) {
|
||||
char *buf = reinterpret_cast<char*>(buf_);
|
||||
size_t ndone = 0;
|
||||
#if !IS_MINGW()
|
||||
while (ndone < len) {
|
||||
ssize_t ret = recv(sockfd, buf,
|
||||
static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
|
||||
@ -405,6 +530,7 @@ class TCPSocket : public Socket{
|
||||
buf += ret;
|
||||
ndone += ret;
|
||||
}
|
||||
#endif // !IS_MINGW()
|
||||
return ndone;
|
||||
}
|
||||
/*!
|
||||
@ -500,7 +626,7 @@ struct PollHelper {
|
||||
pollfd pfd;
|
||||
pfd.fd = fd;
|
||||
pfd.events = POLLPRI;
|
||||
return poll(&pfd, 1, timeout);
|
||||
return PollImpl(&pfd, 1, timeout);
|
||||
}
|
||||
|
||||
/*!
|
||||
@ -514,7 +640,7 @@ struct PollHelper {
|
||||
for (auto kv : fds) {
|
||||
fdset.push_back(kv.second);
|
||||
}
|
||||
int ret = poll(fdset.data(), fdset.size(), timeout);
|
||||
int ret = PollImpl(fdset.data(), fdset.size(), timeout);
|
||||
if (ret == -1) {
|
||||
Socket::Error("Poll");
|
||||
} else {
|
||||
@ -533,4 +659,11 @@ struct PollHelper {
|
||||
};
|
||||
} // namespace utils
|
||||
} // namespace rabit
|
||||
|
||||
#if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
|
||||
#undef POLLIN
|
||||
#undef POLLPRI
|
||||
#undef POLLOUT
|
||||
#endif // IS_MINGW()
|
||||
|
||||
#endif // RABIT_INTERNAL_SOCKET_H_
|
||||
|
||||
@ -15,10 +15,7 @@
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#include "dmlc/io.h"
|
||||
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
#include <cstdarg>
|
||||
#endif // RABIT_STRICT_CXX98_
|
||||
|
||||
#if !defined(__GNUC__) || defined(__FreeBSD__)
|
||||
#define fopen64 std::fopen
|
||||
@ -71,7 +68,6 @@ inline bool StringToBool(const char* s) {
|
||||
return CompareStringsCaseInsensitive(s, "true") == 0 || atoi(s) != 0;
|
||||
}
|
||||
|
||||
#ifndef RABIT_CUSTOMIZE_MSG_
|
||||
/*!
|
||||
* \brief handling of Assert error, caused by inappropriate input
|
||||
* \param msg error message
|
||||
@ -89,6 +85,7 @@ inline void HandleCheckError(const char *msg) {
|
||||
fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
|
||||
throw dmlc::Error(msg);
|
||||
}
|
||||
|
||||
inline void HandlePrint(const char *msg) {
|
||||
printf("%s", msg);
|
||||
}
|
||||
@ -102,22 +99,7 @@ inline void HandleLogInfo(const char *fmt, ...) {
|
||||
fprintf(stdout, "%s", msg.c_str());
|
||||
fflush(stdout);
|
||||
}
|
||||
#else
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
// include declarations, some one must implement this
|
||||
void HandleAssertError(const char *msg);
|
||||
void HandleCheckError(const char *msg);
|
||||
void HandlePrint(const char *msg);
|
||||
#endif // RABIT_STRICT_CXX98_
|
||||
#endif // RABIT_CUSTOMIZE_MSG_
|
||||
#ifdef RABIT_STRICT_CXX98_
|
||||
// these function pointers are to be assigned
|
||||
extern "C" void (*Printf)(const char *fmt, ...);
|
||||
extern "C" int (*SPrintf)(char *buf, size_t size, const char *fmt, ...);
|
||||
extern "C" void (*Assert)(int exp, const char *fmt, ...);
|
||||
extern "C" void (*Check)(int exp, const char *fmt, ...);
|
||||
extern "C" void (*Error)(const char *fmt, ...);
|
||||
#else
|
||||
|
||||
/*! \brief printf, prints messages to the console */
|
||||
inline void Printf(const char *fmt, ...) {
|
||||
std::string msg(kPrintBuffer, '\0');
|
||||
@ -127,6 +109,7 @@ inline void Printf(const char *fmt, ...) {
|
||||
va_end(args);
|
||||
HandlePrint(msg.c_str());
|
||||
}
|
||||
|
||||
/*! \brief portable version of snprintf */
|
||||
inline int SPrintf(char *buf, size_t size, const char *fmt, ...) {
|
||||
va_list args;
|
||||
@ -171,7 +154,6 @@ inline void Error(const char *fmt, ...) {
|
||||
HandleCheckError(msg.c_str());
|
||||
}
|
||||
}
|
||||
#endif // RABIT_STRICT_CXX98_
|
||||
|
||||
/*! \brief replace fopen, report error when the file open fails */
|
||||
inline std::FILE *FopenCheck(const char *fname, const char *flag) {
|
||||
@ -180,6 +162,19 @@ inline std::FILE *FopenCheck(const char *fname, const char *flag) {
|
||||
return fp;
|
||||
}
|
||||
} // namespace utils
|
||||
|
||||
// Can not use std::min on Windows with msvc due to:
|
||||
// error C2589: '(': illegal token on right side of '::'
|
||||
template <typename T>
|
||||
auto Min(T const& l, T const& r) {
|
||||
return l < r ? l : r;
|
||||
}
|
||||
// same with Min
|
||||
template <typename T>
|
||||
auto Max(T const& l, T const& r) {
|
||||
return l > r ? l : r;
|
||||
}
|
||||
|
||||
// easy utils that can be directly accessed in xgboost
|
||||
/*! \brief get the beginning address of a vector */
|
||||
template<typename T>
|
||||
|
||||
@ -8,7 +8,11 @@
|
||||
#define NOMINMAX
|
||||
#include "allreduce_base.h"
|
||||
#include <rabit/base.h>
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <netinet/tcp.h>
|
||||
#endif // _WIN32
|
||||
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
|
||||
@ -413,8 +417,12 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
all_link.sock.SetNonBlock(true);
|
||||
all_link.sock.SetKeepAlive(true);
|
||||
if (rabit_enable_tcp_no_delay) {
|
||||
#if defined(__unix__)
|
||||
setsockopt(all_link.sock, IPPROTO_TCP,
|
||||
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
|
||||
#else
|
||||
fprintf(stderr, "tcp no delay is not implemented on non unix platforms\n");
|
||||
#endif
|
||||
}
|
||||
if (tree_neighbors.count(all_link.rank) != 0) {
|
||||
if (all_link.rank == parent_rank) {
|
||||
|
||||
@ -306,10 +306,11 @@ class AllreduceBase : public IEngine {
|
||||
// constructor
|
||||
LinkRecord() = default;
|
||||
// initialize buffer
|
||||
inline void InitBuffer(size_t type_nbytes, size_t count,
|
||||
size_t reduce_buffer_size) {
|
||||
void InitBuffer(size_t type_nbytes, size_t count,
|
||||
size_t reduce_buffer_size) {
|
||||
size_t n = (type_nbytes * count + 7)/ 8;
|
||||
buffer_.resize(std::min(reduce_buffer_size, n));
|
||||
auto to = Min(reduce_buffer_size, n);
|
||||
buffer_.resize(to);
|
||||
// make sure align to type_nbytes
|
||||
buffer_size =
|
||||
buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
|
||||
@ -338,8 +339,8 @@ class AllreduceBase : public IEngine {
|
||||
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
|
||||
size_t offset = size_read % buffer_size;
|
||||
size_t nmax = max_size_read - size_read;
|
||||
nmax = std::min(nmax, buffer_size - ngap);
|
||||
nmax = std::min(nmax, buffer_size - offset);
|
||||
nmax = Min(nmax, buffer_size - ngap);
|
||||
nmax = Min(nmax, buffer_size - offset);
|
||||
if (nmax == 0) return kSuccess;
|
||||
ssize_t len = sock.Recv(buffer_head + offset, nmax);
|
||||
// length equals 0, remote disconnected
|
||||
|
||||
@ -217,11 +217,11 @@ class AllreduceRobust : public AllreduceBase {
|
||||
*/
|
||||
struct ActionSummary {
|
||||
// maximumly allowed sequence id
|
||||
static const u_int32_t kSpecialOp = (1 << 26);
|
||||
static const uint32_t kSpecialOp = (1 << 26);
|
||||
// special sequence number for local state checkpoint
|
||||
static const u_int32_t kLocalCheckPoint = (1 << 26) - 2;
|
||||
static const uint32_t kLocalCheckPoint = (1 << 26) - 2;
|
||||
// special sequnce number for local state checkpoint ack signal
|
||||
static const u_int32_t kLocalCheckAck = (1 << 26) - 1;
|
||||
static const uint32_t kLocalCheckAck = (1 << 26) - 1;
|
||||
//---------------------------------------------
|
||||
// The following are bit mask of flag used in
|
||||
//----------------------------------------------
|
||||
@ -242,13 +242,13 @@ class AllreduceRobust : public AllreduceBase {
|
||||
ActionSummary() = default;
|
||||
// constructor of action
|
||||
explicit ActionSummary(int seqno_flag, int cache_flag = 0,
|
||||
u_int32_t minseqno = kSpecialOp, u_int32_t maxseqno = kSpecialOp) {
|
||||
uint32_t minseqno = kSpecialOp, uint32_t maxseqno = kSpecialOp) {
|
||||
seqcode_ = (minseqno << 5) | seqno_flag;
|
||||
maxseqcode_ = (maxseqno << 5) | cache_flag;
|
||||
}
|
||||
// minimum number of all operations by default
|
||||
// maximum number of all cache operations otherwise
|
||||
inline u_int32_t Seqno(SeqType t = SeqType::kSeq) const {
|
||||
inline uint32_t Seqno(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||
return code >> 5;
|
||||
}
|
||||
@ -294,8 +294,8 @@ class AllreduceRobust : public AllreduceBase {
|
||||
const ActionSummary *src = static_cast<const ActionSummary*>(src_);
|
||||
ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
u_int32_t min_seqno = std::min(src[i].Seqno(), dst[i].Seqno());
|
||||
u_int32_t max_seqno = std::max(src[i].Seqno(SeqType::kCache),
|
||||
uint32_t min_seqno = Min(src[i].Seqno(), dst[i].Seqno());
|
||||
uint32_t max_seqno = Max(src[i].Seqno(SeqType::kCache),
|
||||
dst[i].Seqno(SeqType::kCache));
|
||||
int action_flag = src[i].Flag() | dst[i].Flag();
|
||||
// if any node is not requester set to 0 otherwise 1
|
||||
@ -310,9 +310,9 @@ class AllreduceRobust : public AllreduceBase {
|
||||
|
||||
private:
|
||||
// internel sequence code min of rabit seqno
|
||||
u_int32_t seqcode_;
|
||||
uint32_t seqcode_;
|
||||
// internal sequence code max of cache seqno
|
||||
u_int32_t maxseqcode_;
|
||||
uint32_t maxseqcode_;
|
||||
};
|
||||
/*! \brief data structure to remember result of Bcast and Allreduce calls*/
|
||||
class ResultBuffer{
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine_mock.cc
|
||||
* \brief this is an engine implementation that will
|
||||
* \brief this is an engine implementation that will
|
||||
* insert failures in certain call point, to test if the engine is robust to failure
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
@ -11,4 +11,3 @@
|
||||
// switch engine to AllreduceMock
|
||||
#define RABIT_USE_BASE
|
||||
#include "engine.cc"
|
||||
|
||||
|
||||
@ -1,132 +0,0 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine_empty.cc
|
||||
* \brief this file provides a dummy implementation of engine that does nothing
|
||||
* this file provides a way to fall back to single node program without causing too many dependencies
|
||||
* This is usually NOT needed, use engine_mpi or engine for real distributed version
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#define NOMINMAX
|
||||
|
||||
#include <rabit/base.h>
|
||||
#include "rabit/internal/engine.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
/*! \brief EmptyEngine */
|
||||
class EmptyEngine : public IEngine {
|
||||
public:
|
||||
EmptyEngine() {
|
||||
version_number_ = 0;
|
||||
}
|
||||
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
|
||||
size_t slice_end, size_t size_prev_slice, const char *_file,
|
||||
const int _line, const char *_caller) override {
|
||||
utils::Error("EmptyEngine:: Allgather is not supported");
|
||||
}
|
||||
int GetRingPrevRank() const override {
|
||||
utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
|
||||
return -1;
|
||||
}
|
||||
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
|
||||
ReduceFunction reducer, PreprocFunction prepare_fun,
|
||||
void *prepare_arg, const char *_file, const int _line,
|
||||
const char *_caller) override {
|
||||
utils::Error("EmptyEngine:: Allreduce is not supported,"\
|
||||
"use Allreduce_ instead");
|
||||
}
|
||||
void Broadcast(void *sendrecvbuf_, size_t size, int root,
|
||||
const char* _file, const int _line, const char* _caller) override {
|
||||
}
|
||||
void InitAfterException() override {
|
||||
utils::Error("EmptyEngine is not fault tolerant");
|
||||
}
|
||||
int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model = nullptr) override {
|
||||
return 0;
|
||||
}
|
||||
void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model = nullptr) override {
|
||||
version_number_ += 1;
|
||||
}
|
||||
void LazyCheckPoint(const Serializable *global_model) override {
|
||||
version_number_ += 1;
|
||||
}
|
||||
int VersionNumber() const override {
|
||||
return version_number_;
|
||||
}
|
||||
/*! \brief get rank of current node */
|
||||
int GetRank() const override {
|
||||
return 0;
|
||||
}
|
||||
/*! \brief get total number of */
|
||||
int GetWorldSize() const override {
|
||||
return 1;
|
||||
}
|
||||
/*! \brief whether it is distributed */
|
||||
bool IsDistributed() const override {
|
||||
return false;
|
||||
}
|
||||
/*! \brief get the host name of current node */
|
||||
std::string GetHost() const override {
|
||||
return std::string("");
|
||||
}
|
||||
void TrackerPrint(const std::string &msg) override {
|
||||
// simply print information into the tracker
|
||||
utils::Printf("%s", msg.c_str());
|
||||
}
|
||||
|
||||
private:
|
||||
int version_number_;
|
||||
};
|
||||
|
||||
// singleton sync manager
|
||||
EmptyEngine manager;
|
||||
|
||||
/*! \brief intiialize the synchronization module */
|
||||
bool Init(int argc, char *argv[]) {
|
||||
return true;
|
||||
}
|
||||
/*! \brief finalize syncrhonization module */
|
||||
bool Finalize() {
|
||||
return true;
|
||||
}
|
||||
|
||||
/*! \brief singleton method to get engine */
|
||||
IEngine *GetEngine() {
|
||||
return &manager;
|
||||
}
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
void Allreduce_(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
IEngine::ReduceFunction red,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
|
||||
}
|
||||
|
||||
// code for reduce handle
|
||||
ReduceHandle::ReduceHandle() = default;
|
||||
ReduceHandle::~ReduceHandle() = default;
|
||||
|
||||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||
return 0;
|
||||
}
|
||||
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {}
|
||||
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||
size_t type_nbytes, size_t count,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
@ -26,22 +26,13 @@ def get_mingw_bin():
|
||||
|
||||
def test_with_autotools(args):
|
||||
with DirectoryExcursion(r_package):
|
||||
if args.compiler == 'mingw':
|
||||
mingw_bin = get_mingw_bin()
|
||||
CXX = os.path.join(mingw_bin, 'g++.exe')
|
||||
CC = os.path.join(mingw_bin, 'gcc.exe')
|
||||
cmd = ['R.exe', 'CMD', 'INSTALL', str(os.path.curdir)]
|
||||
env = os.environ.copy()
|
||||
env.update({'CC': CC, 'CXX': CXX})
|
||||
subprocess.check_call(cmd, env=env)
|
||||
elif args.compiler == 'msvc':
|
||||
cmd = ['R.exe', 'CMD', 'INSTALL', str(os.path.curdir)]
|
||||
env = os.environ.copy()
|
||||
# autotool favor mingw by default.
|
||||
env.update({'CC': 'cl.exe', 'CXX': 'cl.exe'})
|
||||
subprocess.check_call(cmd, env=env)
|
||||
else:
|
||||
raise ValueError('Wrong compiler')
|
||||
mingw_bin = get_mingw_bin()
|
||||
CXX = os.path.join(mingw_bin, 'g++.exe')
|
||||
CC = os.path.join(mingw_bin, 'gcc.exe')
|
||||
cmd = ['R.exe', 'CMD', 'INSTALL', str(os.path.curdir)]
|
||||
env = os.environ.copy()
|
||||
env.update({'CC': CC, 'CXX': CXX})
|
||||
subprocess.check_call(cmd, env=env)
|
||||
subprocess.check_call([
|
||||
'R.exe', '-q', '-e',
|
||||
"library(testthat); setwd('tests'); source('testthat.R')"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user