Enable building rabit on Windows (#6105)
This commit is contained in:
@@ -8,7 +8,11 @@
|
||||
#define NOMINMAX
|
||||
#include "allreduce_base.h"
|
||||
#include <rabit/base.h>
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <netinet/tcp.h>
|
||||
#endif // _WIN32
|
||||
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
|
||||
@@ -413,8 +417,12 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
all_link.sock.SetNonBlock(true);
|
||||
all_link.sock.SetKeepAlive(true);
|
||||
if (rabit_enable_tcp_no_delay) {
|
||||
#if defined(__unix__)
|
||||
setsockopt(all_link.sock, IPPROTO_TCP,
|
||||
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
|
||||
#else
|
||||
fprintf(stderr, "tcp no delay is not implemented on non unix platforms\n");
|
||||
#endif
|
||||
}
|
||||
if (tree_neighbors.count(all_link.rank) != 0) {
|
||||
if (all_link.rank == parent_rank) {
|
||||
|
||||
@@ -306,10 +306,11 @@ class AllreduceBase : public IEngine {
|
||||
// constructor
|
||||
LinkRecord() = default;
|
||||
// initialize buffer
|
||||
inline void InitBuffer(size_t type_nbytes, size_t count,
|
||||
size_t reduce_buffer_size) {
|
||||
void InitBuffer(size_t type_nbytes, size_t count,
|
||||
size_t reduce_buffer_size) {
|
||||
size_t n = (type_nbytes * count + 7)/ 8;
|
||||
buffer_.resize(std::min(reduce_buffer_size, n));
|
||||
auto to = Min(reduce_buffer_size, n);
|
||||
buffer_.resize(to);
|
||||
// make sure align to type_nbytes
|
||||
buffer_size =
|
||||
buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
|
||||
@@ -338,8 +339,8 @@ class AllreduceBase : public IEngine {
|
||||
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
|
||||
size_t offset = size_read % buffer_size;
|
||||
size_t nmax = max_size_read - size_read;
|
||||
nmax = std::min(nmax, buffer_size - ngap);
|
||||
nmax = std::min(nmax, buffer_size - offset);
|
||||
nmax = Min(nmax, buffer_size - ngap);
|
||||
nmax = Min(nmax, buffer_size - offset);
|
||||
if (nmax == 0) return kSuccess;
|
||||
ssize_t len = sock.Recv(buffer_head + offset, nmax);
|
||||
// length equals 0, remote disconnected
|
||||
|
||||
@@ -217,11 +217,11 @@ class AllreduceRobust : public AllreduceBase {
|
||||
*/
|
||||
struct ActionSummary {
|
||||
// maximumly allowed sequence id
|
||||
static const u_int32_t kSpecialOp = (1 << 26);
|
||||
static const uint32_t kSpecialOp = (1 << 26);
|
||||
// special sequence number for local state checkpoint
|
||||
static const u_int32_t kLocalCheckPoint = (1 << 26) - 2;
|
||||
static const uint32_t kLocalCheckPoint = (1 << 26) - 2;
|
||||
// special sequnce number for local state checkpoint ack signal
|
||||
static const u_int32_t kLocalCheckAck = (1 << 26) - 1;
|
||||
static const uint32_t kLocalCheckAck = (1 << 26) - 1;
|
||||
//---------------------------------------------
|
||||
// The following are bit mask of flag used in
|
||||
//----------------------------------------------
|
||||
@@ -242,13 +242,13 @@ class AllreduceRobust : public AllreduceBase {
|
||||
ActionSummary() = default;
|
||||
// constructor of action
|
||||
explicit ActionSummary(int seqno_flag, int cache_flag = 0,
|
||||
u_int32_t minseqno = kSpecialOp, u_int32_t maxseqno = kSpecialOp) {
|
||||
uint32_t minseqno = kSpecialOp, uint32_t maxseqno = kSpecialOp) {
|
||||
seqcode_ = (minseqno << 5) | seqno_flag;
|
||||
maxseqcode_ = (maxseqno << 5) | cache_flag;
|
||||
}
|
||||
// minimum number of all operations by default
|
||||
// maximum number of all cache operations otherwise
|
||||
inline u_int32_t Seqno(SeqType t = SeqType::kSeq) const {
|
||||
inline uint32_t Seqno(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||
return code >> 5;
|
||||
}
|
||||
@@ -294,8 +294,8 @@ class AllreduceRobust : public AllreduceBase {
|
||||
const ActionSummary *src = static_cast<const ActionSummary*>(src_);
|
||||
ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
u_int32_t min_seqno = std::min(src[i].Seqno(), dst[i].Seqno());
|
||||
u_int32_t max_seqno = std::max(src[i].Seqno(SeqType::kCache),
|
||||
uint32_t min_seqno = Min(src[i].Seqno(), dst[i].Seqno());
|
||||
uint32_t max_seqno = Max(src[i].Seqno(SeqType::kCache),
|
||||
dst[i].Seqno(SeqType::kCache));
|
||||
int action_flag = src[i].Flag() | dst[i].Flag();
|
||||
// if any node is not requester set to 0 otherwise 1
|
||||
@@ -310,9 +310,9 @@ class AllreduceRobust : public AllreduceBase {
|
||||
|
||||
private:
|
||||
// internel sequence code min of rabit seqno
|
||||
u_int32_t seqcode_;
|
||||
uint32_t seqcode_;
|
||||
// internal sequence code max of cache seqno
|
||||
u_int32_t maxseqcode_;
|
||||
uint32_t maxseqcode_;
|
||||
};
|
||||
/*! \brief data structure to remember result of Bcast and Allreduce calls*/
|
||||
class ResultBuffer{
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine_mock.cc
|
||||
* \brief this is an engine implementation that will
|
||||
* \brief this is an engine implementation that will
|
||||
* insert failures in certain call point, to test if the engine is robust to failure
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
@@ -11,4 +11,3 @@
|
||||
// switch engine to AllreduceMock
|
||||
#define RABIT_USE_BASE
|
||||
#include "engine.cc"
|
||||
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine_empty.cc
|
||||
* \brief this file provides a dummy implementation of engine that does nothing
|
||||
* this file provides a way to fall back to single node program without causing too many dependencies
|
||||
* This is usually NOT needed, use engine_mpi or engine for real distributed version
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#define NOMINMAX
|
||||
|
||||
#include <rabit/base.h>
|
||||
#include "rabit/internal/engine.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
/*! \brief EmptyEngine */
|
||||
class EmptyEngine : public IEngine {
|
||||
public:
|
||||
EmptyEngine() {
|
||||
version_number_ = 0;
|
||||
}
|
||||
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
|
||||
size_t slice_end, size_t size_prev_slice, const char *_file,
|
||||
const int _line, const char *_caller) override {
|
||||
utils::Error("EmptyEngine:: Allgather is not supported");
|
||||
}
|
||||
int GetRingPrevRank() const override {
|
||||
utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
|
||||
return -1;
|
||||
}
|
||||
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
|
||||
ReduceFunction reducer, PreprocFunction prepare_fun,
|
||||
void *prepare_arg, const char *_file, const int _line,
|
||||
const char *_caller) override {
|
||||
utils::Error("EmptyEngine:: Allreduce is not supported,"\
|
||||
"use Allreduce_ instead");
|
||||
}
|
||||
void Broadcast(void *sendrecvbuf_, size_t size, int root,
|
||||
const char* _file, const int _line, const char* _caller) override {
|
||||
}
|
||||
void InitAfterException() override {
|
||||
utils::Error("EmptyEngine is not fault tolerant");
|
||||
}
|
||||
int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model = nullptr) override {
|
||||
return 0;
|
||||
}
|
||||
void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model = nullptr) override {
|
||||
version_number_ += 1;
|
||||
}
|
||||
void LazyCheckPoint(const Serializable *global_model) override {
|
||||
version_number_ += 1;
|
||||
}
|
||||
int VersionNumber() const override {
|
||||
return version_number_;
|
||||
}
|
||||
/*! \brief get rank of current node */
|
||||
int GetRank() const override {
|
||||
return 0;
|
||||
}
|
||||
/*! \brief get total number of */
|
||||
int GetWorldSize() const override {
|
||||
return 1;
|
||||
}
|
||||
/*! \brief whether it is distributed */
|
||||
bool IsDistributed() const override {
|
||||
return false;
|
||||
}
|
||||
/*! \brief get the host name of current node */
|
||||
std::string GetHost() const override {
|
||||
return std::string("");
|
||||
}
|
||||
void TrackerPrint(const std::string &msg) override {
|
||||
// simply print information into the tracker
|
||||
utils::Printf("%s", msg.c_str());
|
||||
}
|
||||
|
||||
private:
|
||||
int version_number_;
|
||||
};
|
||||
|
||||
// singleton sync manager
|
||||
EmptyEngine manager;
|
||||
|
||||
/*! \brief intiialize the synchronization module */
|
||||
bool Init(int argc, char *argv[]) {
|
||||
return true;
|
||||
}
|
||||
/*! \brief finalize syncrhonization module */
|
||||
bool Finalize() {
|
||||
return true;
|
||||
}
|
||||
|
||||
/*! \brief singleton method to get engine */
|
||||
IEngine *GetEngine() {
|
||||
return &manager;
|
||||
}
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
void Allreduce_(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
IEngine::ReduceFunction red,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
|
||||
}
|
||||
|
||||
// code for reduce handle
|
||||
ReduceHandle::ReduceHandle() = default;
|
||||
ReduceHandle::~ReduceHandle() = default;
|
||||
|
||||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||
return 0;
|
||||
}
|
||||
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {}
|
||||
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||
size_t type_nbytes, size_t count,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
Reference in New Issue
Block a user