Enable building rabit on Windows (#6105)

This commit is contained in:
Jiaming Yuan
2020-09-11 11:54:46 +08:00
committed by GitHub
parent 08bdb2efc8
commit c92d751ad1
17 changed files with 215 additions and 352 deletions

View File

@@ -8,7 +8,11 @@
#define NOMINMAX
#include "allreduce_base.h"
#include <rabit/base.h>
#ifndef _WIN32
#include <netinet/tcp.h>
#endif // _WIN32
#include <cstring>
#include <map>
@@ -413,8 +417,12 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
all_link.sock.SetNonBlock(true);
all_link.sock.SetKeepAlive(true);
if (rabit_enable_tcp_no_delay) {
#if defined(__unix__)
setsockopt(all_link.sock, IPPROTO_TCP,
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
#else
fprintf(stderr, "tcp no delay is not implemented on non unix platforms\n");
#endif
}
if (tree_neighbors.count(all_link.rank) != 0) {
if (all_link.rank == parent_rank) {

View File

@@ -306,10 +306,11 @@ class AllreduceBase : public IEngine {
// constructor
LinkRecord() = default;
// initialize buffer
inline void InitBuffer(size_t type_nbytes, size_t count,
size_t reduce_buffer_size) {
void InitBuffer(size_t type_nbytes, size_t count,
size_t reduce_buffer_size) {
size_t n = (type_nbytes * count + 7)/ 8;
buffer_.resize(std::min(reduce_buffer_size, n));
auto to = Min(reduce_buffer_size, n);
buffer_.resize(to);
// make sure align to type_nbytes
buffer_size =
buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
@@ -338,8 +339,8 @@ class AllreduceBase : public IEngine {
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
size_t offset = size_read % buffer_size;
size_t nmax = max_size_read - size_read;
nmax = std::min(nmax, buffer_size - ngap);
nmax = std::min(nmax, buffer_size - offset);
nmax = Min(nmax, buffer_size - ngap);
nmax = Min(nmax, buffer_size - offset);
if (nmax == 0) return kSuccess;
ssize_t len = sock.Recv(buffer_head + offset, nmax);
// length equals 0, remote disconnected

View File

@@ -217,11 +217,11 @@ class AllreduceRobust : public AllreduceBase {
*/
struct ActionSummary {
// maximumly allowed sequence id
static const u_int32_t kSpecialOp = (1 << 26);
static const uint32_t kSpecialOp = (1 << 26);
// special sequence number for local state checkpoint
static const u_int32_t kLocalCheckPoint = (1 << 26) - 2;
static const uint32_t kLocalCheckPoint = (1 << 26) - 2;
// special sequnce number for local state checkpoint ack signal
static const u_int32_t kLocalCheckAck = (1 << 26) - 1;
static const uint32_t kLocalCheckAck = (1 << 26) - 1;
//---------------------------------------------
// The following are bit mask of flag used in
//----------------------------------------------
@@ -242,13 +242,13 @@ class AllreduceRobust : public AllreduceBase {
ActionSummary() = default;
// constructor of action
explicit ActionSummary(int seqno_flag, int cache_flag = 0,
u_int32_t minseqno = kSpecialOp, u_int32_t maxseqno = kSpecialOp) {
uint32_t minseqno = kSpecialOp, uint32_t maxseqno = kSpecialOp) {
seqcode_ = (minseqno << 5) | seqno_flag;
maxseqcode_ = (maxseqno << 5) | cache_flag;
}
// minimum number of all operations by default
// maximum number of all cache operations otherwise
inline u_int32_t Seqno(SeqType t = SeqType::kSeq) const {
inline uint32_t Seqno(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return code >> 5;
}
@@ -294,8 +294,8 @@ class AllreduceRobust : public AllreduceBase {
const ActionSummary *src = static_cast<const ActionSummary*>(src_);
ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
for (int i = 0; i < len; ++i) {
u_int32_t min_seqno = std::min(src[i].Seqno(), dst[i].Seqno());
u_int32_t max_seqno = std::max(src[i].Seqno(SeqType::kCache),
uint32_t min_seqno = Min(src[i].Seqno(), dst[i].Seqno());
uint32_t max_seqno = Max(src[i].Seqno(SeqType::kCache),
dst[i].Seqno(SeqType::kCache));
int action_flag = src[i].Flag() | dst[i].Flag();
// if any node is not requester set to 0 otherwise 1
@@ -310,9 +310,9 @@ class AllreduceRobust : public AllreduceBase {
private:
// internel sequence code min of rabit seqno
u_int32_t seqcode_;
uint32_t seqcode_;
// internal sequence code max of cache seqno
u_int32_t maxseqcode_;
uint32_t maxseqcode_;
};
/*! \brief data structure to remember result of Bcast and Allreduce calls*/
class ResultBuffer{

View File

@@ -1,7 +1,7 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_mock.cc
* \brief this is an engine implementation that will
* \brief this is an engine implementation that will
* insert failures in certain call point, to test if the engine is robust to failure
* \author Tianqi Chen
*/
@@ -11,4 +11,3 @@
// switch engine to AllreduceMock
#define RABIT_USE_BASE
#include "engine.cc"

View File

@@ -1,132 +0,0 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_empty.cc
* \brief this file provides a dummy implementation of engine that does nothing
* this file provides a way to fall back to single node program without causing too many dependencies
* This is usually NOT needed, use engine_mpi or engine for real distributed version
* \author Tianqi Chen
*/
#define NOMINMAX
#include <rabit/base.h>
#include "rabit/internal/engine.h"
namespace rabit {
namespace engine {
/*! \brief EmptyEngine */
class EmptyEngine : public IEngine {
public:
EmptyEngine() {
version_number_ = 0;
}
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
size_t slice_end, size_t size_prev_slice, const char *_file,
const int _line, const char *_caller) override {
utils::Error("EmptyEngine:: Allgather is not supported");
}
int GetRingPrevRank() const override {
utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
return -1;
}
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
ReduceFunction reducer, PreprocFunction prepare_fun,
void *prepare_arg, const char *_file, const int _line,
const char *_caller) override {
utils::Error("EmptyEngine:: Allreduce is not supported,"\
"use Allreduce_ instead");
}
void Broadcast(void *sendrecvbuf_, size_t size, int root,
const char* _file, const int _line, const char* _caller) override {
}
void InitAfterException() override {
utils::Error("EmptyEngine is not fault tolerant");
}
int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = nullptr) override {
return 0;
}
void CheckPoint(const Serializable *global_model,
const Serializable *local_model = nullptr) override {
version_number_ += 1;
}
void LazyCheckPoint(const Serializable *global_model) override {
version_number_ += 1;
}
int VersionNumber() const override {
return version_number_;
}
/*! \brief get rank of current node */
int GetRank() const override {
return 0;
}
/*! \brief get total number of */
int GetWorldSize() const override {
return 1;
}
/*! \brief whether it is distributed */
bool IsDistributed() const override {
return false;
}
/*! \brief get the host name of current node */
std::string GetHost() const override {
return std::string("");
}
void TrackerPrint(const std::string &msg) override {
// simply print information into the tracker
utils::Printf("%s", msg.c_str());
}
private:
int version_number_;
};
// singleton sync manager
EmptyEngine manager;
/*! \brief intiialize the synchronization module */
bool Init(int argc, char *argv[]) {
return true;
}
/*! \brief finalize syncrhonization module */
bool Finalize() {
return true;
}
/*! \brief singleton method to get engine */
IEngine *GetEngine() {
return &manager;
}
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
}
// code for reduce handle
ReduceHandle::ReduceHandle() = default;
ReduceHandle::~ReduceHandle() = default;
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return 0;
}
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {}
void ReduceHandle::Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
}
} // namespace engine
} // namespace rabit