Correct style warnings from clang-tidy for rabit. (#6095)

This commit is contained in:
Jiaming Yuan 2020-09-08 12:13:58 +08:00 committed by GitHub
parent da61d9460b
commit b0001a6e29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 736 additions and 868 deletions

View File

@ -19,7 +19,7 @@
#define _CALLER "N/A" #define _CALLER "N/A"
#endif // (defined(__GNUC__) && !defined(__clang__)) #endif // (defined(__GNUC__) && !defined(__clang__))
namespace MPI { namespace MPI { // NOLINT
/*! \brief MPI data type just to be compatible with MPI reduce function*/ /*! \brief MPI data type just to be compatible with MPI reduce function*/
class Datatype; class Datatype;
} }
@ -36,7 +36,7 @@ class IEngine {
* used to prepare the data used by AllReduce * used to prepare the data used by AllReduce
* \param arg additional possible argument used to invoke the preprocessor * \param arg additional possible argument used to invoke the preprocessor
*/ */
typedef void (PreprocFunction) (void *arg); typedef void (PreprocFunction) (void *arg); // NOLINT
/*! /*!
* \brief reduce function, the same form of MPI reduce function is used, * \brief reduce function, the same form of MPI reduce function is used,
* to be compatible with MPI interface * to be compatible with MPI interface
@ -48,11 +48,11 @@ class IEngine {
* the definition of the reduce function should be type aware * the definition of the reduce function should be type aware
* \param dtype the data type object, to be compatible with MPI reduce * \param dtype the data type object, to be compatible with MPI reduce
*/ */
typedef void (ReduceFunction) (const void *src, typedef void (ReduceFunction) (const void *src, // NOLINT
void *dst, int count, void *dst, int count,
const MPI::Datatype &dtype); const MPI::Datatype &dtype);
/*! \brief virtual destructor */ /*! \brief virtual destructor */
virtual ~IEngine() {} ~IEngine() = default;
/*! /*!
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf, * \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end), * the data provided by current node k is [slice_begin, slice_end),
@ -68,7 +68,7 @@ class IEngine {
* \param _file caller file name used to generate unique cache key * \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key * \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key * \param _caller caller function name used to generate unique cache key
*/ */
virtual void Allgather(void *sendrecvbuf, virtual void Allgather(void *sendrecvbuf,
size_t total_size, size_t total_size,
size_t slice_begin, size_t slice_begin,
@ -96,8 +96,8 @@ class IEngine {
size_t type_nbytes, size_t type_nbytes,
size_t count, size_t count,
ReduceFunction reducer, ReduceFunction reducer,
PreprocFunction prepare_fun = NULL, PreprocFunction prepare_fun = nullptr,
void *prepare_arg = NULL, void *prepare_arg = nullptr,
const char* _file = _FILE, const char* _file = _FILE,
const int _line = _LINE, const int _line = _LINE,
const char* _caller = _CALLER) = 0; const char* _caller = _CALLER) = 0;
@ -119,7 +119,7 @@ class IEngine {
* call this function when IEngine throws an exception, * call this function when IEngine throws an exception,
* this function should only be used for test purposes * this function should only be used for test purposes
*/ */
virtual void InitAfterException(void) = 0; virtual void InitAfterException() = 0;
/*! /*!
* \brief loads the latest check point * \brief loads the latest check point
* \param global_model pointer to the globally shared model/state * \param global_model pointer to the globally shared model/state
@ -143,7 +143,7 @@ class IEngine {
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
virtual int LoadCheckPoint(Serializable *global_model, virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL) = 0; Serializable *local_model = nullptr) = 0;
/*! /*!
* \brief checkpoints the model, meaning a stage of execution was finished * \brief checkpoints the model, meaning a stage of execution was finished
* every time we call check point, a version number increases by ones * every time we call check point, a version number increases by ones
@ -161,7 +161,7 @@ class IEngine {
* \sa LoadCheckPoint, VersionNumber * \sa LoadCheckPoint, VersionNumber
*/ */
virtual void CheckPoint(const Serializable *global_model, virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) = 0; const Serializable *local_model = nullptr) = 0;
/*! /*!
* \brief This function can be used to replace CheckPoint for global_model only, * \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met (see detailed explanation). * when certain condition is met (see detailed explanation).
@ -188,17 +188,17 @@ class IEngine {
* which means how many calls to CheckPoint we made so far * which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint * \sa LoadCheckPoint, CheckPoint
*/ */
virtual int VersionNumber(void) const = 0; virtual int VersionNumber() const = 0;
/*! \brief gets rank of previous node in ring topology */ /*! \brief gets rank of previous node in ring topology */
virtual int GetRingPrevRank(void) const = 0; virtual int GetRingPrevRank() const = 0;
/*! \brief gets rank of current node */ /*! \brief gets rank of current node */
virtual int GetRank(void) const = 0; virtual int GetRank() const = 0;
/*! \brief gets total number of nodes */ /*! \brief gets total number of nodes */
virtual int GetWorldSize(void) const = 0; virtual int GetWorldSize() const = 0;
/*! \brief whether we run in distribted mode */ /*! \brief whether we run in distribted mode */
virtual bool IsDistributed(void) const = 0; virtual bool IsDistributed() const = 0;
/*! \brief gets the host name of the current node */ /*! \brief gets the host name of the current node */
virtual std::string GetHost(void) const = 0; virtual std::string GetHost() const = 0;
/*! /*!
* \brief prints the msg in the tracker, * \brief prints the msg in the tracker,
* this function can be used to communicate progress information to * this function can be used to communicate progress information to
@ -211,9 +211,9 @@ class IEngine {
/*! \brief initializes the engine module */ /*! \brief initializes the engine module */
bool Init(int argc, char *argv[]); bool Init(int argc, char *argv[]);
/*! \brief finalizes the engine module */ /*! \brief finalizes the engine module */
bool Finalize(void); bool Finalize();
/*! \brief singleton method to get engine */ /*! \brief singleton method to get engine */
IEngine *GetEngine(void); IEngine *GetEngine();
/*! \brief namespace that contains stubs to be compatible with MPI */ /*! \brief namespace that contains stubs to be compatible with MPI */
namespace mpi { namespace mpi {
@ -280,14 +280,14 @@ void Allgather(void* sendrecvbuf,
* \param _line caller line number used to generate unique cache key * \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key * \param _caller caller function name used to generate unique cache key
*/ */
void Allreduce_(void *sendrecvbuf, void Allreduce_(void *sendrecvbuf, // NOLINT
size_t type_nbytes, size_t type_nbytes,
size_t count, size_t count,
IEngine::ReduceFunction red, IEngine::ReduceFunction red,
mpi::DataType dtype, mpi::DataType dtype,
mpi::OpType op, mpi::OpType op,
IEngine::PreprocFunction prepare_fun = NULL, IEngine::PreprocFunction prepare_fun = nullptr,
void *prepare_arg = NULL, void *prepare_arg = nullptr,
const char* _file = _FILE, const char* _file = _FILE,
const int _line = _LINE, const int _line = _LINE,
const char* _caller = _CALLER); const char* _caller = _CALLER);
@ -298,9 +298,9 @@ void Allreduce_(void *sendrecvbuf,
class ReduceHandle { class ReduceHandle {
public: public:
// constructor // constructor
ReduceHandle(void); ReduceHandle();
// destructor // destructor
~ReduceHandle(void); ~ReduceHandle();
/*! /*!
* \brief initialize the reduce function, * \brief initialize the reduce function,
* with the type the reduce function needs to deal with * with the type the reduce function needs to deal with
@ -323,8 +323,8 @@ class ReduceHandle {
void Allreduce(void *sendrecvbuf, void Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t type_nbytes,
size_t count, size_t count,
IEngine::PreprocFunction prepare_fun = NULL, IEngine::PreprocFunction prepare_fun = nullptr,
void *prepare_arg = NULL, void *prepare_arg = nullptr,
const char* _file = _FILE, const char* _file = _FILE,
const int _line = _LINE, const int _line = _LINE,
const char* _caller = _CALLER); const char* _caller = _CALLER);
@ -333,11 +333,11 @@ class ReduceHandle {
protected: protected:
// handle function field // handle function field
void *handle_; void *handle_ {nullptr};
// reduce function of the reducer // reduce function of the reducer
IEngine::ReduceFunction *redfunc_; IEngine::ReduceFunction *redfunc_{nullptr};
// handle to the type field // handle to the type field
void *htype_; void *htype_{nullptr};
// the created type in 4 bytes // the created type in 4 bytes
size_t created_type_nbytes_; size_t created_type_nbytes_;
}; };

View File

@ -19,12 +19,12 @@
namespace rabit { namespace rabit {
namespace utils { namespace utils {
/*! \brief re-use definition of dmlc::SeekStream */ /*! \brief re-use definition of dmlc::SeekStream */
typedef dmlc::SeekStream SeekStream; using SeekStream = dmlc::SeekStream;
/*! \brief fixed size memory buffer */ /*! \brief fixed size memory buffer */
struct MemoryFixSizeBuffer : public SeekStream { struct MemoryFixSizeBuffer : public SeekStream {
public: public:
// similar to SEEK_END in libc // similar to SEEK_END in libc
static size_t constexpr SeekEnd = std::numeric_limits<size_t>::max(); static size_t constexpr kSeekEnd = std::numeric_limits<size_t>::max();
public: public:
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size) MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
@ -32,31 +32,31 @@ struct MemoryFixSizeBuffer : public SeekStream {
buffer_size_(buffer_size) { buffer_size_(buffer_size) {
curr_ptr_ = 0; curr_ptr_ = 0;
} }
virtual ~MemoryFixSizeBuffer(void) {} ~MemoryFixSizeBuffer() override = default;
virtual size_t Read(void *ptr, size_t size) { size_t Read(void *ptr, size_t size) override {
size_t nread = std::min(buffer_size_ - curr_ptr_, size); size_t nread = std::min(buffer_size_ - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread); if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
curr_ptr_ += nread; curr_ptr_ += nread;
return nread; return nread;
} }
virtual void Write(const void *ptr, size_t size) { void Write(const void *ptr, size_t size) override {
if (size == 0) return; if (size == 0) return;
utils::Assert(curr_ptr_ + size <= buffer_size_, utils::Assert(curr_ptr_ + size <= buffer_size_,
"write position exceed fixed buffer size"); "write position exceed fixed buffer size");
std::memcpy(p_buffer_ + curr_ptr_, ptr, size); std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
curr_ptr_ += size; curr_ptr_ += size;
} }
virtual void Seek(size_t pos) { void Seek(size_t pos) override {
if (pos == SeekEnd) { if (pos == kSeekEnd) {
curr_ptr_ = buffer_size_; curr_ptr_ = buffer_size_;
} else { } else {
curr_ptr_ = static_cast<size_t>(pos); curr_ptr_ = static_cast<size_t>(pos);
} }
} }
virtual size_t Tell(void) { size_t Tell() override {
return curr_ptr_; return curr_ptr_;
} }
virtual bool AtEnd(void) const { virtual bool AtEnd() const {
return curr_ptr_ == buffer_size_; return curr_ptr_ == buffer_size_;
} }
@ -76,8 +76,8 @@ struct MemoryBufferStream : public SeekStream {
: p_buffer_(p_buffer) { : p_buffer_(p_buffer) {
curr_ptr_ = 0; curr_ptr_ = 0;
} }
virtual ~MemoryBufferStream(void) {} ~MemoryBufferStream() override = default;
virtual size_t Read(void *ptr, size_t size) { size_t Read(void *ptr, size_t size) override {
utils::Assert(curr_ptr_ <= p_buffer_->length(), utils::Assert(curr_ptr_ <= p_buffer_->length(),
"read can not have position excceed buffer length"); "read can not have position excceed buffer length");
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size); size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
@ -85,7 +85,7 @@ struct MemoryBufferStream : public SeekStream {
curr_ptr_ += nread; curr_ptr_ += nread;
return nread; return nread;
} }
virtual void Write(const void *ptr, size_t size) { void Write(const void *ptr, size_t size) override {
if (size == 0) return; if (size == 0) return;
if (curr_ptr_ + size > p_buffer_->length()) { if (curr_ptr_ + size > p_buffer_->length()) {
p_buffer_->resize(curr_ptr_+size); p_buffer_->resize(curr_ptr_+size);
@ -93,13 +93,13 @@ struct MemoryBufferStream : public SeekStream {
std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size); std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
curr_ptr_ += size; curr_ptr_ += size;
} }
virtual void Seek(size_t pos) { void Seek(size_t pos) override {
curr_ptr_ = static_cast<size_t>(pos); curr_ptr_ = static_cast<size_t>(pos);
} }
virtual size_t Tell(void) { size_t Tell() override {
return curr_ptr_; return curr_ptr_;
} }
virtual bool AtEnd(void) const { virtual bool AtEnd() const {
return curr_ptr_ == p_buffer_->length(); return curr_ptr_ == p_buffer_->length();
} }

View File

@ -19,45 +19,45 @@ namespace engine {
namespace mpi { namespace mpi {
// template function to translate type to enum indicator // template function to translate type to enum indicator
template<typename DType> template<typename DType>
inline DataType GetType(void); inline DataType GetType();
template<> template<>
inline DataType GetType<char>(void) { inline DataType GetType<char>() {
return kChar; return kChar;
} }
template<> template<>
inline DataType GetType<unsigned char>(void) { inline DataType GetType<unsigned char>() {
return kUChar; return kUChar;
} }
template<> template<>
inline DataType GetType<int>(void) { inline DataType GetType<int>() {
return kInt; return kInt;
} }
template<> template<>
inline DataType GetType<unsigned int>(void) { // NOLINT(*) inline DataType GetType<unsigned int>() { // NOLINT(*)
return kUInt; return kUInt;
} }
template<> template<>
inline DataType GetType<long>(void) { // NOLINT(*) inline DataType GetType<long>() { // NOLINT(*)
return kLong; return kLong;
} }
template<> template<>
inline DataType GetType<unsigned long>(void) { // NOLINT(*) inline DataType GetType<unsigned long>() { // NOLINT(*)
return kULong; return kULong;
} }
template<> template<>
inline DataType GetType<float>(void) { inline DataType GetType<float>() {
return kFloat; return kFloat;
} }
template<> template<>
inline DataType GetType<double>(void) { inline DataType GetType<double>() {
return kDouble; return kDouble;
} }
template<> template<>
inline DataType GetType<long long>(void) { // NOLINT(*) inline DataType GetType<long long>() { // NOLINT(*)
return kLongLong; return kLongLong;
} }
template<> template<>
inline DataType GetType<unsigned long long>(void) { // NOLINT(*) inline DataType GetType<unsigned long long>() { // NOLINT(*)
return kULongLong; return kULongLong;
} }
} // namespace mpi } // namespace mpi
@ -94,7 +94,7 @@ struct BitOR {
}; };
template<typename OP, typename DType> template<typename OP, typename DType>
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) { inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
const DType* src = (const DType*)src_; const DType* src = static_cast<const DType*>(src_);
DType* dst = (DType*)dst_; // NOLINT(*) DType* dst = (DType*)dst_; // NOLINT(*)
for (int i = 0; i < len; i++) { for (int i = 0; i < len; i++) {
OP::Reduce(dst[i], src[i]); OP::Reduce(dst[i], src[i]);
@ -107,27 +107,27 @@ inline bool Init(int argc, char *argv[]) {
return engine::Init(argc, argv); return engine::Init(argc, argv);
} }
// finalize the rabit engine // finalize the rabit engine
inline bool Finalize(void) { inline bool Finalize() {
return engine::Finalize(); return engine::Finalize();
} }
// get the rank of the previous worker in ring topology // get the rank of the previous worker in ring topology
inline int GetRingPrevRank(void) { inline int GetRingPrevRank() {
return engine::GetEngine()->GetRingPrevRank(); return engine::GetEngine()->GetRingPrevRank();
} }
// get the rank of current process // get the rank of current process
inline int GetRank(void) { inline int GetRank() {
return engine::GetEngine()->GetRank(); return engine::GetEngine()->GetRank();
} }
// the the size of the world // the the size of the world
inline int GetWorldSize(void) { inline int GetWorldSize() {
return engine::GetEngine()->GetWorldSize(); return engine::GetEngine()->GetWorldSize();
} }
// whether rabit is distributed // whether rabit is distributed
inline bool IsDistributed(void) { inline bool IsDistributed() {
return engine::GetEngine()->IsDistributed(); return engine::GetEngine()->IsDistributed();
} }
// get the name of current processor // get the name of current processor
inline std::string GetProcessorName(void) { inline std::string GetProcessorName() {
return engine::GetEngine()->GetHost(); return engine::GetEngine()->GetHost();
} }
// broadcast data to all other nodes from root // broadcast data to all other nodes from root
@ -183,7 +183,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
// C++11 support for lambda prepare function // C++11 support for lambda prepare function
#if DMLC_USE_CXX11 #if DMLC_USE_CXX11
inline void InvokeLambda_(void *fun) { inline void InvokeLambda(void *fun) {
(*static_cast<std::function<void()>*>(fun))(); (*static_cast<std::function<void()>*>(fun))();
} }
template<typename OP, typename DType> template<typename OP, typename DType>
@ -193,7 +193,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
const int _line, const int _line,
const char* _caller) { const char* _caller) {
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>, engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun, engine::mpi::GetType<DType>(), OP::kType, InvokeLambda, &prepare_fun,
_file, _line, _caller); _file, _line, _caller);
} }
@ -245,7 +245,7 @@ inline void LazyCheckPoint(const Serializable *global_model) {
engine::GetEngine()->LazyCheckPoint(global_model); engine::GetEngine()->LazyCheckPoint(global_model);
} }
// return the version number of currently stored model // return the version number of currently stored model
inline int VersionNumber(void) { inline int VersionNumber() {
return engine::GetEngine()->VersionNumber(); return engine::GetEngine()->VersionNumber();
} }
// --------------------------------- // ---------------------------------
@ -253,7 +253,7 @@ inline int VersionNumber(void) {
// --------------------------------- // ---------------------------------
// function to perform reduction for Reducer // function to perform reduction for Reducer
template<typename DType, void (*freduce)(DType &dst, const DType &src)> template<typename DType, void (*freduce)(DType &dst, const DType &src)>
inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) { inline void ReducerSafeImpl(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
const size_t kUnit = sizeof(DType); const size_t kUnit = sizeof(DType);
const char *psrc = reinterpret_cast<const char*>(src_); const char *psrc = reinterpret_cast<const char*>(src_);
char *pdst = reinterpret_cast<char*>(dst_); char *pdst = reinterpret_cast<char*>(dst_);
@ -269,7 +269,7 @@ inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Data
} }
// function to perform reduction for Reducer // function to perform reduction for Reducer
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*) template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
inline void ReducerAlign_(const void *src_, void *dst_, inline void ReducerAlignImpl(const void *src_, void *dst_,
int len_, const MPI::Datatype &dtype) { int len_, const MPI::Datatype &dtype) {
const DType *psrc = reinterpret_cast<const DType*>(src_); const DType *psrc = reinterpret_cast<const DType*>(src_);
DType *pdst = reinterpret_cast<DType*>(dst_); DType *pdst = reinterpret_cast<DType*>(dst_);
@ -278,12 +278,12 @@ inline void ReducerAlign_(const void *src_, void *dst_,
} }
} }
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*) template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
inline Reducer<DType, freduce>::Reducer(void) { inline Reducer<DType, freduce>::Reducer() {
// it is safe to directly use handle for aligned data types // it is safe to directly use handle for aligned data types
if (sizeof(DType) == 8 || sizeof(DType) == 4 || sizeof(DType) == 1) { if (sizeof(DType) == 8 || sizeof(DType) == 4 || sizeof(DType) == 1) {
this->handle_.Init(ReducerAlign_<DType, freduce>, sizeof(DType)); this->handle_.Init(ReducerAlignImpl<DType, freduce>, sizeof(DType));
} else { } else {
this->handle_.Init(ReducerSafe_<DType, freduce>, sizeof(DType)); this->handle_.Init(ReducerSafeImpl<DType, freduce>, sizeof(DType));
} }
} }
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*) template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
@ -298,8 +298,8 @@ inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
} }
// function to perform reduction for SerializeReducer // function to perform reduction for SerializeReducer
template<typename DType> template<typename DType>
inline void SerializeReducerFunc_(const void *src_, void *dst_, inline void SerializeReducerFuncImpl(const void *src_, void *dst_,
int len_, const MPI::Datatype &dtype) { int len_, const MPI::Datatype &dtype) {
int nbytes = engine::ReduceHandle::TypeSize(dtype); int nbytes = engine::ReduceHandle::TypeSize(dtype);
// temp space // temp space
for (int i = 0; i < len_; ++i) { for (int i = 0; i < len_; ++i) {
@ -315,8 +315,8 @@ inline void SerializeReducerFunc_(const void *src_, void *dst_,
} }
} }
template<typename DType> template<typename DType>
inline SerializeReducer<DType>::SerializeReducer(void) { inline SerializeReducer<DType>::SerializeReducer() {
handle_.Init(SerializeReducerFunc_<DType>, sizeof(DType)); handle_.Init(SerializeReducerFuncImpl<DType>, sizeof(DType));
} }
// closure to call Allreduce // closure to call Allreduce
template<typename DType> template<typename DType>
@ -327,8 +327,8 @@ struct SerializeReduceClosure {
void *prepare_arg; void *prepare_arg;
std::string *p_buffer; std::string *p_buffer;
// invoke the closure // invoke the closure
inline void Run(void) { inline void Run() {
if (prepare_fun != NULL) prepare_fun(prepare_arg); if (prepare_fun != nullptr) prepare_fun(prepare_arg);
for (size_t i = 0; i < count; ++i) { for (size_t i = 0; i < count; ++i) {
utils::MemoryFixSizeBuffer fs(BeginPtr(*p_buffer) + i * max_nbyte, max_nbyte); utils::MemoryFixSizeBuffer fs(BeginPtr(*p_buffer) + i * max_nbyte, max_nbyte);
sendrecvobj[i].Save(fs); sendrecvobj[i].Save(fs);
@ -368,7 +368,7 @@ inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
const char* _file, const char* _file,
const int _line, const int _line,
const char* _caller) { const char* _caller) {
this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun, this->Allreduce(sendrecvbuf, count, InvokeLambda, &prepare_fun,
_file, _line, _caller); _file, _line, _caller);
} }
template<typename DType> template<typename DType>
@ -378,7 +378,7 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
const char* _file, const char* _file,
const int _line, const int _line,
const char* _caller) { const char* _caller) {
this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda_, &prepare_fun, this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda, &prepare_fun,
_file, _line, _caller); _file, _line, _caller);
} }
#endif // DMLC_USE_CXX11 #endif // DMLC_USE_CXX11

View File

@ -15,7 +15,7 @@
#else #else
#include <fcntl.h> #include <fcntl.h>
#include <netdb.h> #include <netdb.h>
#include <errno.h> #include <cerrno>
#include <unistd.h> #include <unistd.h>
#include <arpa/inet.h> #include <arpa/inet.h>
#include <netinet/in.h> #include <netinet/in.h>
@ -39,9 +39,9 @@ static inline int poll(struct pollfd *pfd, int nfds,
int timeout) { return WSAPoll ( pfd, nfds, timeout ); } int timeout) { return WSAPoll ( pfd, nfds, timeout ); }
#else #else
#include <sys/poll.h> #include <sys/poll.h>
typedef int SOCKET; using SOCKET = int;
typedef size_t sock_size_t; using sock_size_t = size_t; // NOLINT
const int INVALID_SOCKET = -1; const int kInvalidSocket = -1;
#endif // defined(_WIN32) #endif // defined(_WIN32)
namespace rabit { namespace rabit {
@ -50,11 +50,11 @@ namespace utils {
struct SockAddr { struct SockAddr {
sockaddr_in addr; sockaddr_in addr;
// constructor // constructor
SockAddr(void) {} SockAddr() = default;
SockAddr(const char *url, int port) { SockAddr(const char *url, int port) {
this->Set(url, port); this->Set(url, port);
} }
inline static std::string GetHostName(void) { inline static std::string GetHostName() {
std::string buf; buf.resize(256); std::string buf; buf.resize(256);
utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name"); utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
return std::string(buf.c_str()); return std::string(buf.c_str());
@ -69,20 +69,20 @@ struct SockAddr {
memset(&hints, 0, sizeof(hints)); memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_INET; hints.ai_family = AF_INET;
hints.ai_protocol = SOCK_STREAM; hints.ai_protocol = SOCK_STREAM;
addrinfo *res = NULL; addrinfo *res = nullptr;
int sig = getaddrinfo(host, NULL, &hints, &res); int sig = getaddrinfo(host, nullptr, &hints, &res);
Check(sig == 0 && res != NULL, "cannot obtain address of %s", host); Check(sig == 0 && res != nullptr, "cannot obtain address of %s", host);
Check(res->ai_family == AF_INET, "Does not support IPv6"); Check(res->ai_family == AF_INET, "Does not support IPv6");
memcpy(&addr, res->ai_addr, res->ai_addrlen); memcpy(&addr, res->ai_addr, res->ai_addrlen);
addr.sin_port = htons(port); addr.sin_port = htons(port);
freeaddrinfo(res); freeaddrinfo(res);
} }
/*! \brief return port of the address*/ /*! \brief return port of the address*/
inline int port(void) const { inline int Port() const {
return ntohs(addr.sin_port); return ntohs(addr.sin_port);
} }
/*! \return a string representation of the address */ /*! \return a string representation of the address */
inline std::string AddrStr(void) const { inline std::string AddrStr() const {
std::string buf; buf.resize(256); std::string buf; buf.resize(256);
#ifdef _WIN32 #ifdef _WIN32
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
@ -91,7 +91,7 @@ struct SockAddr {
const char *s = inet_ntop(AF_INET, &addr.sin_addr, const char *s = inet_ntop(AF_INET, &addr.sin_addr,
&buf[0], buf.length()); &buf[0], buf.length());
#endif // _WIN32 #endif // _WIN32
Assert(s != NULL, "cannot decode address"); Assert(s != nullptr, "cannot decode address");
return std::string(s); return std::string(s);
} }
}; };
@ -104,13 +104,13 @@ class Socket {
/*! \brief the file descriptor of socket */ /*! \brief the file descriptor of socket */
SOCKET sockfd; SOCKET sockfd;
// default conversion to int // default conversion to int
inline operator SOCKET() const { operator SOCKET() const { // NOLINT
return sockfd; return sockfd;
} }
/*! /*!
* \return last error of socket operation * \return last error of socket operation
*/ */
inline static int GetLastError(void) { inline static int GetLastError() {
#ifdef _WIN32 #ifdef _WIN32
return WSAGetLastError(); return WSAGetLastError();
#else #else
@ -118,7 +118,7 @@ class Socket {
#endif // _WIN32 #endif // _WIN32
} }
/*! \return whether last error was would block */ /*! \return whether last error was would block */
inline static bool LastErrorWouldBlock(void) { inline static bool LastErrorWouldBlock() {
int errsv = GetLastError(); int errsv = GetLastError();
#ifdef _WIN32 #ifdef _WIN32
return errsv == WSAEWOULDBLOCK; return errsv == WSAEWOULDBLOCK;
@ -130,7 +130,7 @@ class Socket {
* \brief start up the socket module * \brief start up the socket module
* call this before using the sockets * call this before using the sockets
*/ */
inline static void Startup(void) { inline static void Startup() {
#ifdef _WIN32 #ifdef _WIN32
WSADATA wsa_data; WSADATA wsa_data;
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) { if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
@ -145,7 +145,7 @@ class Socket {
/*! /*!
* \brief shutdown the socket module after use, all sockets need to be closed * \brief shutdown the socket module after use, all sockets need to be closed
*/ */
inline static void Finalize(void) { inline static void Finalize() {
#ifdef _WIN32 #ifdef _WIN32
WSACleanup(); WSACleanup();
#endif // _WIN32 #endif // _WIN32
@ -214,7 +214,7 @@ class Socket {
return -1; return -1;
} }
/*! \brief get last error code if any */ /*! \brief get last error code if any */
inline int GetSockError(void) const { inline int GetSockError() const {
int error = 0; int error = 0;
socklen_t len = sizeof(error); socklen_t len = sizeof(error);
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
@ -224,25 +224,25 @@ class Socket {
return error; return error;
} }
/*! \brief check if anything bad happens */ /*! \brief check if anything bad happens */
inline bool BadSocket(void) const { inline bool BadSocket() const {
if (IsClosed()) return true; if (IsClosed()) return true;
int err = GetSockError(); int err = GetSockError();
if (err == EBADF || err == EINTR) return true; if (err == EBADF || err == EINTR) return true;
return false; return false;
} }
/*! \brief check if socket is already closed */ /*! \brief check if socket is already closed */
inline bool IsClosed(void) const { inline bool IsClosed() const {
return sockfd == INVALID_SOCKET; return sockfd == kInvalidSocket;
} }
/*! \brief close the socket */ /*! \brief close the socket */
inline void Close(void) { inline void Close() {
if (sockfd != INVALID_SOCKET) { if (sockfd != kInvalidSocket) {
#ifdef _WIN32 #ifdef _WIN32
closesocket(sockfd); closesocket(sockfd);
#else #else
close(sockfd); close(sockfd);
#endif #endif
sockfd = INVALID_SOCKET; sockfd = kInvalidSocket;
} else { } else {
Error("Socket::Close double close the socket or close without create"); Error("Socket::Close double close the socket or close without create");
} }
@ -268,7 +268,7 @@ class Socket {
class TCPSocket : public Socket{ class TCPSocket : public Socket{
public: public:
// constructor // constructor
TCPSocket(void) : Socket(INVALID_SOCKET) { TCPSocket() : Socket(kInvalidSocket) {
} }
explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) { explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) {
} }
@ -297,7 +297,7 @@ class TCPSocket : public Socket{
*/ */
inline void Create(int af = PF_INET) { inline void Create(int af = PF_INET) {
sockfd = socket(PF_INET, SOCK_STREAM, 0); sockfd = socket(PF_INET, SOCK_STREAM, 0);
if (sockfd == INVALID_SOCKET) { if (sockfd == kInvalidSocket) {
Socket::Error("Create"); Socket::Error("Create");
} }
} }
@ -309,9 +309,9 @@ class TCPSocket : public Socket{
listen(sockfd, backlog); listen(sockfd, backlog);
} }
/*! \brief get a new connection */ /*! \brief get a new connection */
TCPSocket Accept(void) { TCPSocket Accept() {
SOCKET newfd = accept(sockfd, NULL, NULL); SOCKET newfd = accept(sockfd, nullptr, nullptr);
if (newfd == INVALID_SOCKET) { if (newfd == kInvalidSocket) {
Socket::Error("Accept"); Socket::Error("Accept");
} }
return TCPSocket(newfd); return TCPSocket(newfd);
@ -320,7 +320,7 @@ class TCPSocket : public Socket{
* \brief decide whether the socket is at OOB mark * \brief decide whether the socket is at OOB mark
* \return 1 if at mark, 0 if not, -1 if an error occurred * \return 1 if at mark, 0 if not, -1 if an error occurred
*/ */
inline int AtMark(void) const { inline int AtMark() const {
#ifdef _WIN32 #ifdef _WIN32
unsigned long atmark; // NOLINT(*) unsigned long atmark; // NOLINT(*)
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1; if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;

View File

@ -1,87 +0,0 @@
/*!
* Copyright (c) 2015 by Contributors
* \file thread_local.h
* \brief Common utility for thread local storage.
*/
#ifndef RABIT_INTERNAL_THREAD_LOCAL_H_
#define RABIT_INTERNAL_THREAD_LOCAL_H_
#include "../include/dmlc/base.h"
#if DMLC_ENABLE_STD_THREAD
#include <mutex>
#endif // DMLC_ENABLE_STD_THREAD
#include <memory>
#include <vector>
namespace rabit {
// Macro handling for thread-local variables: pick the compiler-specific
// storage-class keyword and expose it as MX_TREAD_LOCAL.
#ifdef __GNUC__
#define MX_TREAD_LOCAL __thread
#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L
#define MX_TREAD_LOCAL _Thread_local
#elif defined(_MSC_VER)
#define MX_TREAD_LOCAL __declspec(thread)
#endif  // __GNUC__
#ifndef MX_TREAD_LOCAL
// `#message` is not a preprocessor directive and is rejected as an error once
// this branch is active; `#pragma message` emits the intended diagnostic.
#pragma message("Warning: Threadlocal is not enabled")
#endif  // MX_TREAD_LOCAL
/*!
 * \brief A threadlocal store to store threadlocal variables.
 *  Get() returns a per-thread singleton of type T that is created with
 *  `new T()` on the first call made by each thread. Every created object is
 *  registered with a process-wide store and deleted when that store is
 *  destroyed.
 * \tparam T the type we like to store, must be default constructible
 */
template<typename T>
class ThreadLocalStore {
 public:
  /*! \return get a thread local singleton */
  static T* Get() {
    // Keep the compiler-specific keyword where the surrounding header selected
    // one; otherwise fall back to standard C++11 `thread_local` so the class
    // still compiles when MX_TREAD_LOCAL is left undefined.
#ifdef MX_TREAD_LOCAL
    static MX_TREAD_LOCAL T* ptr = nullptr;
#else
    static thread_local T* ptr = nullptr;
#endif  // MX_TREAD_LOCAL
    if (ptr == nullptr) {
      ptr = new T();
      // hand ownership to the shared store so the object is eventually freed
      Singleton()->RegisterDelete(ptr);
    }
    return ptr;
  }

 private:
  /*! \brief constructor, private: the only instance is the one in Singleton() */
  ThreadLocalStore() = default;
  /*! \brief destructor, deletes every object registered through RegisterDelete */
  ~ThreadLocalStore() {
    for (T* item : data_) {
      delete item;
    }
  }
  /*! \return singleton of the store */
  static ThreadLocalStore<T> *Singleton() {
    static ThreadLocalStore<T> inst;
    return &inst;
  }
  /*!
   * \brief register a pointer for internal deletion
   * \param str the pointer to be deleted by the destructor
   */
  void RegisterDelete(T *str) {
#if DMLC_ENABLE_STD_THREAD
    // scoped guard: released on every exit path, including a throwing push_back
    std::lock_guard<std::mutex> lock(mutex_);
#endif  // DMLC_ENABLE_STD_THREAD
    data_.push_back(str);
  }
#if DMLC_ENABLE_STD_THREAD
  /*! \brief internal mutex guarding data_ */
  std::mutex mutex_;
#endif  // DMLC_ENABLE_STD_THREAD
  /*!\brief internal data: all objects handed out by Get() */
  std::vector<T*> data_;
};
} // namespace rabit
#endif // RABIT_INTERNAL_THREAD_LOCAL_H_

View File

@ -6,7 +6,7 @@
*/ */
#ifndef RABIT_INTERNAL_TIMER_H_ #ifndef RABIT_INTERNAL_TIMER_H_
#define RABIT_INTERNAL_TIMER_H_ #define RABIT_INTERNAL_TIMER_H_
#include <time.h> #include <ctime>
#ifdef __MACH__ #ifdef __MACH__
#include <mach/clock.h> #include <mach/clock.h>
#include <mach/mach.h> #include <mach/mach.h>
@ -18,7 +18,7 @@ namespace utils {
/*! /*!
* \brief return time in seconds, not cross platform, avoid to use this in most places * \brief return time in seconds, not cross platform, avoid to use this in most places
*/ */
inline double GetTime(void) { inline double GetTime() {
#ifdef __MACH__ #ifdef __MACH__
clock_serv_t cclock; clock_serv_t cclock;
mach_timespec_t mts; mach_timespec_t mts;

View File

@ -8,7 +8,7 @@
#define RABIT_INTERNAL_UTILS_H_ #define RABIT_INTERNAL_UTILS_H_
#include <rabit/base.h> #include <rabit/base.h>
#include <string.h> #include <cstring>
#include <cstdio> #include <cstdio>
#include <string> #include <string>
#include <cstdlib> #include <cstdlib>
@ -48,15 +48,7 @@ extern "C" {
} }
#endif // _MSC_VER #endif // _MSC_VER
#ifdef _MSC_VER #include <cinttypes>
typedef unsigned char uint8_t;
typedef unsigned __int16 uint16_t;
typedef unsigned __int32 uint32_t;
typedef unsigned __int64 uint64_t;
typedef __int64 int64_t;
#else
#include <inttypes.h>
#endif // _MSC_VER
namespace rabit { namespace rabit {
/*! \brief namespace for helper utils of the project */ /*! \brief namespace for helper utils of the project */
@ -184,7 +176,7 @@ inline void Error(const char *fmt, ...) {
/*! \brief replace fopen, report error when the file open fails */ /*! \brief replace fopen, report error when the file open fails */
inline std::FILE *FopenCheck(const char *fname, const char *flag) { inline std::FILE *FopenCheck(const char *fname, const char *flag) {
std::FILE *fp = fopen64(fname, flag); std::FILE *fp = fopen64(fname, flag);
Check(fp != NULL, "can not open file \"%s\"\n", fname); Check(fp != nullptr, "can not open file \"%s\"\n", fname);
return fp; return fp;
} }
} // namespace utils } // namespace utils
@ -193,7 +185,7 @@ inline std::FILE *FopenCheck(const char *fname, const char *flag) {
template<typename T> template<typename T>
inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*) inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*)
if (vec.size() == 0) { if (vec.size() == 0) {
return NULL; return nullptr;
} else { } else {
return &vec[0]; return &vec[0];
} }
@ -202,17 +194,17 @@ inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*)
template<typename T> template<typename T>
inline const T *BeginPtr(const std::vector<T> &vec) { // NOLINT(*) inline const T *BeginPtr(const std::vector<T> &vec) { // NOLINT(*)
if (vec.size() == 0) { if (vec.size() == 0) {
return NULL; return nullptr;
} else { } else {
return &vec[0]; return &vec[0];
} }
} }
inline char* BeginPtr(std::string &str) { // NOLINT(*) inline char* BeginPtr(std::string &str) { // NOLINT(*)
if (str.length() == 0) return NULL; if (str.length() == 0) return nullptr;
return &str[0]; return &str[0];
} }
inline const char* BeginPtr(const std::string &str) { inline const char* BeginPtr(const std::string &str) {
if (str.length() == 0) return NULL; if (str.length() == 0) return nullptr;
return &str[0]; return &str[0];
} }
} // namespace rabit } // namespace rabit

View File

@ -53,12 +53,12 @@ namespace rabit {
* \brief defines stream used in rabit * \brief defines stream used in rabit
* see definition of Stream in dmlc/io.h * see definition of Stream in dmlc/io.h
*/ */
typedef dmlc::Stream Stream; using Stream = dmlc::Stream;
/*! /*!
* \brief defines serializable objects used in rabit * \brief defines serializable objects used in rabit
* see definition of Serializable in dmlc/io.h * see definition of Serializable in dmlc/io.h
*/ */
typedef dmlc::Serializable Serializable; using Serializable = dmlc::Serializable;
/*! /*!
* \brief reduction operators namespace * \brief reduction operators namespace
@ -199,8 +199,8 @@ inline void Broadcast(std::string *sendrecv_data, int root,
*/ */
template<typename OP, typename DType> template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count, inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *) = NULL, void (*prepare_fun)(void *) = nullptr,
void *prepare_arg = NULL, void *prepare_arg = nullptr,
const char* _file = _FILE, const char* _file = _FILE,
const int _line = _LINE, const int _line = _LINE,
const char* _caller = _CALLER); const char* _caller = _CALLER);
@ -220,7 +220,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
* \param _file caller file name used to generate unique cache key * \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key * \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key * \param _caller caller function name used to generate unique cache key
*/ */
template<typename DType> template<typename DType>
inline void Allgather(DType *sendrecvbuf_, inline void Allgather(DType *sendrecvbuf_,
size_t total_size, size_t total_size,
@ -291,7 +291,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
inline int LoadCheckPoint(Serializable *global_model, inline int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL); Serializable *local_model = nullptr);
/*! /*!
* \brief checkpoints the model, meaning a stage of execution has finished. * \brief checkpoints the model, meaning a stage of execution has finished.
* every time we call check point, a version number will be increased by one * every time we call check point, a version number will be increased by one
@ -307,7 +307,7 @@ inline int LoadCheckPoint(Serializable *global_model,
* \sa LoadCheckPoint, VersionNumber * \sa LoadCheckPoint, VersionNumber
*/ */
inline void CheckPoint(const Serializable *global_model, inline void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL); const Serializable *local_model = nullptr);
/*! /*!
* \brief This function can be used to replace CheckPoint for global_model only, * \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met (see detailed explanation). * when certain condition is met (see detailed explanation).
@ -366,8 +366,8 @@ class Reducer {
* \param _caller caller function name used to generate unique cache key * \param _caller caller function name used to generate unique cache key
*/ */
inline void Allreduce(DType *sendrecvbuf, size_t count, inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *) = NULL, void (*prepare_fun)(void *) = nullptr,
void *prepare_arg = NULL, void *prepare_arg = nullptr,
const char* _file = _FILE, const char* _file = _FILE,
const int _line = _LINE, const int _line = _LINE,
const char* _caller = _CALLER); const char* _caller = _CALLER);
@ -422,8 +422,8 @@ class SerializeReducer {
*/ */
inline void Allreduce(DType *sendrecvobj, inline void Allreduce(DType *sendrecvobj,
size_t max_nbyte, size_t count, size_t max_nbyte, size_t count,
void (*prepare_fun)(void *) = NULL, void (*prepare_fun)(void *) = nullptr,
void *prepare_arg = NULL, void *prepare_arg = nullptr,
const char* _file = _FILE, const char* _file = _FILE,
const int _line = _LINE, const int _line = _LINE,
const char* _caller = _CALLER); const char* _caller = _CALLER);

View File

@ -15,12 +15,12 @@ namespace rabit {
* \brief defines stream used in rabit * \brief defines stream used in rabit
* see definition of Stream in dmlc/io.h * see definition of Stream in dmlc/io.h
*/ */
typedef dmlc::Stream Stream; using Stream = dmlc::Stream ;
/*! /*!
* \brief defines serializable objects used in rabit * \brief defines serializable objects used in rabit
* see definition of Serializable in dmlc/io.h * see definition of Serializable in dmlc/io.h
*/ */
typedef dmlc::Serializable Serializable; using Serializable = dmlc::Serializable;
} // namespace rabit } // namespace rabit
#endif // RABIT_SERIALIZABLE_H_ #endif // RABIT_SERIALIZABLE_H_

View File

@ -15,7 +15,7 @@
namespace rabit { namespace rabit {
namespace engine { namespace engine {
// constructor // constructor
AllreduceBase::AllreduceBase(void) { AllreduceBase::AllreduceBase() {
tracker_uri = "NULL"; tracker_uri = "NULL";
tracker_port = 9000; tracker_port = 9000;
host_uri = ""; host_uri = "";
@ -24,7 +24,7 @@ AllreduceBase::AllreduceBase(void) {
rank = 0; rank = 0;
world_size = -1; world_size = -1;
connect_retry = 5; connect_retry = 5;
hadoop_mode = 0; hadoop_mode = false;
version_number = 0; version_number = 0;
// 32 K items // 32 K items
reduce_ring_mincount = 32 << 10; reduce_ring_mincount = 32 << 10;
@ -32,27 +32,27 @@ AllreduceBase::AllreduceBase(void) {
tree_reduce_minsize = 1 << 20; tree_reduce_minsize = 1 << 20;
// tracker URL // tracker URL
task_id = "NULL"; task_id = "NULL";
err_link = NULL; err_link = nullptr;
dmlc_role = "worker"; dmlc_role = "worker";
this->SetParam("rabit_reduce_buffer", "256MB"); this->SetParam("rabit_reduce_buffer", "256MB");
// setup possible environment variable of interest // setup possible environment variable of interest
// include dmlc support direct variables // include dmlc support direct variables
env_vars.push_back("DMLC_TASK_ID"); env_vars.emplace_back("DMLC_TASK_ID");
env_vars.push_back("DMLC_ROLE"); env_vars.emplace_back("DMLC_ROLE");
env_vars.push_back("DMLC_NUM_ATTEMPT"); env_vars.emplace_back("DMLC_NUM_ATTEMPT");
env_vars.push_back("DMLC_TRACKER_URI"); env_vars.emplace_back("DMLC_TRACKER_URI");
env_vars.push_back("DMLC_TRACKER_PORT"); env_vars.emplace_back("DMLC_TRACKER_PORT");
env_vars.push_back("DMLC_WORKER_CONNECT_RETRY"); env_vars.emplace_back("DMLC_WORKER_CONNECT_RETRY");
} }
// initialization function // initialization function
bool AllreduceBase::Init(int argc, char* argv[]) { bool AllreduceBase::Init(int argc, char* argv[]) {
// setup from environment variables // setup from environment variables
// handler to get variables from env // handler to get variables from env
for (size_t i = 0; i < env_vars.size(); ++i) { for (auto & env_var : env_vars) {
const char *value = getenv(env_vars[i].c_str()); const char *value = getenv(env_var.c_str());
if (value != NULL) { if (value != nullptr) {
this->SetParam(env_vars[i].c_str(), value); this->SetParam(env_var.c_str(), value);
} }
} }
// pass in arguments override env variable. // pass in arguments override env variable.
@ -66,35 +66,35 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
{ {
// handling for hadoop // handling for hadoop
const char *task_id = getenv("mapred_tip_id"); const char *task_id = getenv("mapred_tip_id");
if (task_id == NULL) { if (task_id == nullptr) {
task_id = getenv("mapreduce_task_id"); task_id = getenv("mapreduce_task_id");
} }
if (hadoop_mode) { if (hadoop_mode) {
utils::Check(task_id != NULL, utils::Check(task_id != nullptr,
"hadoop_mode is set but cannot find mapred_task_id"); "hadoop_mode is set but cannot find mapred_task_id");
} }
if (task_id != NULL) { if (task_id != nullptr) {
this->SetParam("rabit_task_id", task_id); this->SetParam("rabit_task_id", task_id);
this->SetParam("rabit_hadoop_mode", "1"); this->SetParam("rabit_hadoop_mode", "1");
} }
const char *attempt_id = getenv("mapred_task_id"); const char *attempt_id = getenv("mapred_task_id");
if (attempt_id != 0) { if (attempt_id != nullptr) {
const char *att = strrchr(attempt_id, '_'); const char *att = strrchr(attempt_id, '_');
int num_trial; int num_trial;
if (att != NULL && sscanf(att + 1, "%d", &num_trial) == 1) { if (att != nullptr && sscanf(att + 1, "%d", &num_trial) == 1) {
this->SetParam("rabit_num_trial", att + 1); this->SetParam("rabit_num_trial", att + 1);
} }
} }
// handling for hadoop // handling for hadoop
const char *num_task = getenv("mapred_map_tasks"); const char *num_task = getenv("mapred_map_tasks");
if (num_task == NULL) { if (num_task == nullptr) {
num_task = getenv("mapreduce_job_maps"); num_task = getenv("mapreduce_job_maps");
} }
if (hadoop_mode) { if (hadoop_mode) {
utils::Check(num_task != NULL, utils::Check(num_task != nullptr,
"hadoop_mode is set but cannot find mapred_map_tasks"); "hadoop_mode is set but cannot find mapred_map_tasks");
} }
if (num_task != NULL) { if (num_task != nullptr) {
this->SetParam("rabit_world_size", num_task); this->SetParam("rabit_world_size", num_task);
} }
} }
@ -115,10 +115,10 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
return this->ReConnectLinks(); return this->ReConnectLinks();
} }
bool AllreduceBase::Shutdown(void) { bool AllreduceBase::Shutdown() {
try { try {
for (size_t i = 0; i < all_links.size(); ++i) { for (auto & all_link : all_links) {
all_links[i].sock.Close(); all_link.sock.Close();
} }
all_links.clear(); all_links.clear();
tree_links.plinks.clear(); tree_links.plinks.clear();
@ -208,17 +208,18 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non negative second"); utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non negative second");
} }
if (!strcmp(name, "rabit_enable_tcp_no_delay")) { if (!strcmp(name, "rabit_enable_tcp_no_delay")) {
if (!strcmp(val, "true")) if (!strcmp(val, "true")) {
rabit_enable_tcp_no_delay = true; rabit_enable_tcp_no_delay = true;
else } else {
rabit_enable_tcp_no_delay = false; rabit_enable_tcp_no_delay = false;
}
} }
} }
/*! /*!
* \brief initialize connection to the tracker * \brief initialize connection to the tracker
* \return a socket that initializes the connection * \return a socket that initializes the connection
*/ */
utils::TCPSocket AllreduceBase::ConnectTracker(void) const { utils::TCPSocket AllreduceBase::ConnectTracker() const {
int magic = kMagic; int magic = kMagic;
// get information from tracker // get information from tracker
utils::TCPSocket tracker; utils::TCPSocket tracker;
@ -241,7 +242,7 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
} }
} }
break; break;
} while (1); } while (true);
using utils::Assert; using utils::Assert;
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic), Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
@ -320,19 +321,19 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
do { do {
// send over good links // send over good links
std::vector<int> good_link; std::vector<int> good_link;
for (size_t i = 0; i < all_links.size(); ++i) { for (auto & all_link : all_links) {
if (!all_links[i].sock.BadSocket()) { if (!all_link.sock.BadSocket()) {
good_link.push_back(static_cast<int>(all_links[i].rank)); good_link.push_back(static_cast<int>(all_link.rank));
} else { } else {
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close(); if (!all_link.sock.IsClosed()) all_link.sock.Close();
} }
} }
int ngood = static_cast<int>(good_link.size()); int ngood = static_cast<int>(good_link.size());
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood), Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
"ReConnectLink failure 5"); "ReConnectLink failure 5");
for (size_t i = 0; i < good_link.size(); ++i) { for (int & i : good_link) {
Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \ Assert(tracker.SendAll(&i, sizeof(i)) == \
sizeof(good_link[i]), "ReConnectLink failure 6"); sizeof(i), "ReConnectLink failure 6");
} }
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn), Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
"ReConnectLink failure 7"); "ReConnectLink failure 7");
@ -362,11 +363,11 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
utils::Check(hrank == r.rank, utils::Check(hrank == r.rank,
"ReConnectLink failure, link rank inconsistent"); "ReConnectLink failure, link rank inconsistent");
bool match = false; bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) { for (auto & all_link : all_links) {
if (all_links[i].rank == hrank) { if (all_link.rank == hrank) {
Assert(all_links[i].sock.IsClosed(), Assert(all_link.sock.IsClosed(),
"Override a link that is active"); "Override a link that is active");
all_links[i].sock = r.sock; all_link.sock = r.sock;
match = true; match = true;
break; break;
} }
@ -390,11 +391,11 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
"ReConnectLink failure 15"); "ReConnectLink failure 15");
bool match = false; bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) { for (auto & all_link : all_links) {
if (all_links[i].rank == r.rank) { if (all_link.rank == r.rank) {
utils::Assert(all_links[i].sock.IsClosed(), utils::Assert(all_link.sock.IsClosed(),
"Override a link that is active"); "Override a link that is active");
all_links[i].sock = r.sock; all_link.sock = r.sock;
match = true; match = true;
break; break;
} }
@ -406,29 +407,29 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
// setup tree links and ring structure // setup tree links and ring structure
tree_links.plinks.clear(); tree_links.plinks.clear();
int tcpNoDelay = 1; int tcpNoDelay = 1;
for (size_t i = 0; i < all_links.size(); ++i) { for (auto & all_link : all_links) {
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket"); utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode, enable TCP keepalive // set the socket to non-blocking mode, enable TCP keepalive
all_links[i].sock.SetNonBlock(true); all_link.sock.SetNonBlock(true);
all_links[i].sock.SetKeepAlive(true); all_link.sock.SetKeepAlive(true);
if (rabit_enable_tcp_no_delay) { if (rabit_enable_tcp_no_delay) {
setsockopt(all_links[i].sock, IPPROTO_TCP, setsockopt(all_link.sock, IPPROTO_TCP,
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay)); TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
} }
if (tree_neighbors.count(all_links[i].rank) != 0) { if (tree_neighbors.count(all_link.rank) != 0) {
if (all_links[i].rank == parent_rank) { if (all_link.rank == parent_rank) {
parent_index = static_cast<int>(tree_links.plinks.size()); parent_index = static_cast<int>(tree_links.plinks.size());
} }
tree_links.plinks.push_back(&all_links[i]); tree_links.plinks.push_back(&all_link);
} }
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i]; if (all_link.rank == prev_rank) ring_prev = &all_link;
if (all_links[i].rank == next_rank) ring_next = &all_links[i]; if (all_link.rank == next_rank) ring_next = &all_link;
} }
Assert(parent_rank == -1 || parent_index != -1, Assert(parent_rank == -1 || parent_index != -1,
"cannot find parent in the link"); "cannot find parent in the link");
Assert(prev_rank == -1 || ring_prev != NULL, Assert(prev_rank == -1 || ring_prev != nullptr,
"cannot find prev ring in the link"); "cannot find prev ring in the link");
Assert(next_rank == -1 || ring_next != NULL, Assert(next_rank == -1 || ring_next != nullptr,
"cannot find next ring in the link"); "cannot find next ring in the link");
return true; return true;
} catch (const std::exception& e) { } catch (const std::exception& e) {
@ -479,11 +480,11 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
size_t count, size_t count,
ReduceFunction reducer) { ReduceFunction reducer) {
RefLinkVector &links = tree_links; RefLinkVector &links = tree_links;
if (links.size() == 0 || count == 0) return kSuccess; if (links.Size() == 0 || count == 0) return kSuccess;
// total size of message // total size of message
const size_t total_size = type_nbytes * count; const size_t total_size = type_nbytes * count;
// number of links // number of links
const int nlink = static_cast<int>(links.size()); const int nlink = static_cast<int>(links.Size());
// send recv buffer // send recv buffer
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_); char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
// size of space that we already performs reduce in up pass // size of space that we already performs reduce in up pass
@ -581,8 +582,9 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
// if max reduce is less than total size, we reduce multiple times of // if max reduce is less than total size, we reduce multiple times of
// eachreduce size // eachreduce size
if (max_reduce < total_size) if (max_reduce < total_size) {
max_reduce = max_reduce - max_reduce % eachreduce; max_reduce = max_reduce - max_reduce % eachreduce;
}
// perform reduce, can be at most two rounds // perform reduce, can be at most two rounds
while (size_up_reduce < max_reduce) { while (size_up_reduce < max_reduce) {
@ -677,11 +679,11 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
AllreduceBase::ReturnType AllreduceBase::ReturnType
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
RefLinkVector &links = tree_links; RefLinkVector &links = tree_links;
if (links.size() == 0 || total_size == 0) return kSuccess; if (links.Size() == 0 || total_size == 0) return kSuccess;
utils::Check(root < world_size, utils::Check(root < world_size,
"Broadcast: root should be smaller than world size"); "Broadcast: root should be smaller than world size");
// number of links // number of links
const int nlink = static_cast<int>(links.size()); const int nlink = static_cast<int>(links.Size());
// size of space already read from data // size of space already read from data
size_t size_in = 0; size_t size_in = 0;
// input link, -2 means unknown yet, -1 means this is root // input link, -2 means unknown yet, -1 means this is root

View File

@ -25,7 +25,7 @@
#endif // RABIT_CXXTESTDEFS_H #endif // RABIT_CXXTESTDEFS_H
namespace MPI { namespace MPI { // NOLINT
// MPI data type to be compatible with existing MPI interface // MPI data type to be compatible with existing MPI interface
class Datatype { class Datatype {
public: public:
@ -41,12 +41,12 @@ class AllreduceBase : public IEngine {
// magic number to verify server // magic number to verify server
static const int kMagic = 0xff99; static const int kMagic = 0xff99;
// constant one byte out of band message to indicate error happening // constant one byte out of band message to indicate error happening
AllreduceBase(void); AllreduceBase();
virtual ~AllreduceBase(void) {} virtual ~AllreduceBase() = default;
// initialize the manager // initialize the manager
virtual bool Init(int argc, char* argv[]); virtual bool Init(int argc, char* argv[]);
// shutdown the engine // shutdown the engine
virtual bool Shutdown(void); virtual bool Shutdown();
/*! /*!
* \brief set parameters to the engine * \brief set parameters to the engine
* \param name parameter name * \param name parameter name
@ -59,27 +59,27 @@ class AllreduceBase : public IEngine {
* the user who monitors the tracker * the user who monitors the tracker
* \param msg message to be printed in the tracker * \param msg message to be printed in the tracker
*/ */
virtual void TrackerPrint(const std::string &msg); void TrackerPrint(const std::string &msg) override;
/*! \brief get rank of previous node in ring topology*/ /*! \brief get rank of previous node in ring topology*/
virtual int GetRingPrevRank(void) const { int GetRingPrevRank() const override {
return ring_prev->rank; return ring_prev->rank;
} }
/*! \brief get rank */ /*! \brief get rank */
virtual int GetRank(void) const { int GetRank() const override {
return rank; return rank;
} }
/*! \brief get world size */ /*! \brief get world size */
virtual int GetWorldSize(void) const { int GetWorldSize() const override {
if (world_size == -1) return 1; if (world_size == -1) return 1;
return world_size; return world_size;
} }
/*! \brief whether is distributed or not */ /*! \brief whether is distributed or not */
virtual bool IsDistributed(void) const { bool IsDistributed() const override {
return tracker_uri != "NULL"; return tracker_uri != "NULL";
} }
/*! \brief get host uri */ /*! \brief get host uri */
virtual std::string GetHost(void) const { std::string GetHost() const override {
return host_uri; return host_uri;
} }
@ -99,13 +99,10 @@ class AllreduceBase : public IEngine {
* \param _line caller line number used to generate unique cache key * \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key * \param _caller caller function name used to generate unique cache key
*/ */
virtual void Allgather(void *sendrecvbuf_, size_t total_size, void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
size_t slice_begin, size_t slice_end, size_t size_prev_slice,
size_t slice_end, const char *_file = _FILE, const int _line = _LINE,
size_t size_prev_slice, const char *_caller = _CALLER) override {
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
if (world_size == 1 || world_size == -1) return; if (world_size == 1 || world_size == -1) return;
utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size, utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size,
slice_begin, slice_end, size_prev_slice) == kSuccess, slice_begin, slice_end, size_prev_slice) == kSuccess,
@ -126,16 +123,12 @@ class AllreduceBase : public IEngine {
* \param _line caller line number used to generate unique cache key * \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key * \param _caller caller function name used to generate unique cache key
*/ */
virtual void Allreduce(void *sendrecvbuf_, void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
size_t type_nbytes, ReduceFunction reducer, PreprocFunction prepare_fun = nullptr,
size_t count, void *prepare_arg = nullptr, const char *_file = _FILE,
ReduceFunction reducer, const int _line = _LINE,
PreprocFunction prepare_fun = NULL, const char *_caller = _CALLER) override {
void *prepare_arg = NULL, if (prepare_fun != nullptr) prepare_fun(prepare_arg);
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
if (world_size == 1 || world_size == -1) return; if (world_size == 1 || world_size == -1) return;
utils::Assert(TryAllreduce(sendrecvbuf_, utils::Assert(TryAllreduce(sendrecvbuf_,
type_nbytes, count, reducer) == kSuccess, type_nbytes, count, reducer) == kSuccess,
@ -150,8 +143,9 @@ class AllreduceBase : public IEngine {
* \param _line caller line number used to generate unique cache key * \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key * \param _caller caller function name used to generate unique cache key
*/ */
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root, void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char* _file = _FILE, const int _line = _LINE, const char* _caller = _CALLER) { const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override {
if (world_size == 1 || world_size == -1) return; if (world_size == 1 || world_size == -1) return;
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess, utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
"Broadcast failed"); "Broadcast failed");
@ -178,8 +172,8 @@ class AllreduceBase : public IEngine {
* *
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
virtual int LoadCheckPoint(Serializable *global_model, int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL) { Serializable *local_model = nullptr) override {
return 0; return 0;
} }
/*! /*!
@ -198,8 +192,8 @@ class AllreduceBase : public IEngine {
* *
* \sa LoadCheckPoint, VersionNumber * \sa LoadCheckPoint, VersionNumber
*/ */
virtual void CheckPoint(const Serializable *global_model, void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) { const Serializable *local_model = nullptr) override {
version_number += 1; version_number += 1;
} }
/*! /*!
@ -222,7 +216,7 @@ class AllreduceBase : public IEngine {
* is the same in all nodes * is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber * \sa LoadCheckPoint, CheckPoint, VersionNumber
*/ */
virtual void LazyCheckPoint(const Serializable *global_model) { void LazyCheckPoint(const Serializable *global_model) override {
version_number += 1; version_number += 1;
} }
/*! /*!
@ -230,7 +224,7 @@ class AllreduceBase : public IEngine {
* which means how many calls to CheckPoint we made so far * which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint * \sa LoadCheckPoint, CheckPoint
*/ */
virtual int VersionNumber(void) const { int VersionNumber() const override {
return version_number; return version_number;
} }
/*! /*!
@ -238,14 +232,14 @@ class AllreduceBase : public IEngine {
* call this function when IEngine throws an exception, * call this function when IEngine throws an exception,
* this function is only used for test purpose * this function is only used for test purpose
*/ */
virtual void InitAfterException(void) { void InitAfterException() override {
utils::Error("InitAfterException: not implemented"); utils::Error("InitAfterException: not implemented");
} }
/*! /*!
* \brief report current status to the job tracker * \brief report current status to the job tracker
* depending on the job tracker we are in * depending on the job tracker we are in
*/ */
inline void ReportStatus(void) const { inline void ReportStatus() const {
if (hadoop_mode != 0) { if (hadoop_mode != 0) {
fprintf(stderr, "reporter:status:Rabit Phase[%03d] Operation %03d\n", fprintf(stderr, "reporter:status:Rabit Phase[%03d] Operation %03d\n",
version_number, seq_counter); version_number, seq_counter);
@ -274,7 +268,7 @@ class AllreduceBase : public IEngine {
/*! \brief internal return type */ /*! \brief internal return type */
ReturnTypeEnum value; ReturnTypeEnum value;
// constructor // constructor
ReturnType() {} ReturnType() = default;
ReturnType(ReturnTypeEnum value) : value(value) {} // NOLINT(*) ReturnType(ReturnTypeEnum value) : value(value) {} // NOLINT(*)
inline bool operator==(const ReturnTypeEnum &v) const { inline bool operator==(const ReturnTypeEnum &v) const {
return value == v; return value == v;
@ -306,13 +300,11 @@ class AllreduceBase : public IEngine {
// size of data sent to the link // size of data sent to the link
size_t size_write; size_t size_write;
// pointer to buffer head // pointer to buffer head
char *buffer_head; char *buffer_head {nullptr};
// buffer size, in bytes // buffer size, in bytes
size_t buffer_size; size_t buffer_size {0};
// constructor // constructor
LinkRecord(void) LinkRecord() = default;
: buffer_head(NULL), buffer_size(0) {
}
// initialize buffer // initialize buffer
inline void InitBuffer(size_t type_nbytes, size_t count, inline void InitBuffer(size_t type_nbytes, size_t count,
size_t reduce_buffer_size) { size_t reduce_buffer_size) {
@ -328,7 +320,7 @@ class AllreduceBase : public IEngine {
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_)); buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
} }
// reset the recv and sent size // reset the recv and sent size
inline void ResetSize(void) { inline void ResetSize() {
size_write = size_read = 0; size_write = size_read = 0;
} }
/*! /*!
@ -340,7 +332,7 @@ class AllreduceBase : public IEngine {
* \return the type of reading * \return the type of reading
*/ */
inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) { inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) {
utils::Assert(buffer_head != NULL, "ReadToRingBuffer: buffer not allocated"); utils::Assert(buffer_head != nullptr, "ReadToRingBuffer: buffer not allocated");
utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check"); utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check");
size_t ngap = size_read - protect_start; size_t ngap = size_read - protect_start;
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check"); utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
@ -405,7 +397,7 @@ class AllreduceBase : public IEngine {
inline LinkRecord &operator[](size_t i) { inline LinkRecord &operator[](size_t i) {
return *plinks[i]; return *plinks[i];
} }
inline size_t size(void) const { inline size_t Size() const {
return plinks.size(); return plinks.size();
} }
}; };
@ -413,7 +405,7 @@ class AllreduceBase : public IEngine {
* \brief initialize connection to the tracker * \brief initialize connection to the tracker
* \return a socket that initializes the connection * \return a socket that initializes the connection
*/ */
utils::TCPSocket ConnectTracker(void) const; utils::TCPSocket ConnectTracker() const;
/*! /*!
* \brief connect to the tracker to fix the missing links * \brief connect to the tracker to fix the missing links
* this function is also used when the engine start up * this function is also used when the engine start up
@ -525,64 +517,64 @@ class AllreduceBase : public IEngine {
//---- data structure related to model ---- //---- data structure related to model ----
// call sequence counter, records how many calls we made so far // call sequence counter, records how many calls we made so far
// from last call to CheckPoint, LoadCheckPoint // from last call to CheckPoint, LoadCheckPoint
int seq_counter; int seq_counter; // NOLINT
// version number of model // version number of model
int version_number; int version_number; // NOLINT
// whether the job is running in hadoop // whether the job is running in hadoop
bool hadoop_mode; bool hadoop_mode; // NOLINT
//---- local data related to link ---- //---- local data related to link ----
// index of parent link, can be -1, meaning this is root of the tree // index of parent link, can be -1, meaning this is root of the tree
int parent_index; int parent_index; // NOLINT
// rank of parent node, can be -1 // rank of parent node, can be -1
int parent_rank; int parent_rank; // NOLINT
// sockets of all links this connects to // sockets of all links this connects to
std::vector<LinkRecord> all_links; std::vector<LinkRecord> all_links; // NOLINT
// used to record the link where things go wrong // used to record the link where things go wrong
LinkRecord *err_link; LinkRecord *err_link; // NOLINT
// all the links in the reduction tree connection // all the links in the reduction tree connection
RefLinkVector tree_links; RefLinkVector tree_links; // NOLINT
// pointer to links in the ring // pointer to links in the ring
LinkRecord *ring_prev, *ring_next; LinkRecord *ring_prev, *ring_next; // NOLINT
//----- meta information----- //----- meta information-----
// list of environment variables that are of possible interest // list of environment variables that are of possible interest
std::vector<std::string> env_vars; std::vector<std::string> env_vars; // NOLINT
// unique identifier of the possible job this process is doing // unique identifier of the possible job this process is doing
// used to assign ranks, optional, default to NULL // used to assign ranks, optional, default to NULL
std::string task_id; std::string task_id; // NOLINT
// uri of current host, to be set by Init // uri of current host, to be set by Init
std::string host_uri; std::string host_uri; // NOLINT
// uri of tracker // uri of tracker
std::string tracker_uri; std::string tracker_uri; // NOLINT
// role in dmlc jobs // role in dmlc jobs
std::string dmlc_role; std::string dmlc_role; // NOLINT
// port of tracker address // port of tracker address
int tracker_port; int tracker_port; // NOLINT
// port of slave process // port of slave process
int slave_port, nport_trial; int slave_port, nport_trial; // NOLINT
// reduce buffer size // reduce buffer size
size_t reduce_buffer_size; size_t reduce_buffer_size; // NOLINT
// reduction method // reduction method
int reduce_method; int reduce_method; // NOLINT
// minimum count of cells to use ring based method // minimum count of cells to use ring based method
size_t reduce_ring_mincount; size_t reduce_ring_mincount; // NOLINT
// minimum block size per tree reduce // minimum block size per tree reduce
size_t tree_reduce_minsize; size_t tree_reduce_minsize; // NOLINT
// current rank // current rank
int rank; int rank; // NOLINT
// world size // world size
int world_size; int world_size; // NOLINT
// connect retry time // connect retry time
int connect_retry; int connect_retry; // NOLINT
// enable bootstrap cache 0 false 1 true // enable bootstrap cache 0 false 1 true
bool rabit_bootstrap_cache = false; bool rabit_bootstrap_cache = false; // NOLINT
// enable detailed logging // enable detailed logging
bool rabit_debug = false; bool rabit_debug = false; // NOLINT
// by default, exit if a rabit worker does not recover within half an hour // by default, exit if a rabit worker does not recover within half an hour
int timeout_sec = 1800; int timeout_sec = 1800; // NOLINT
// flag to enable rabit_timeout // flag to enable rabit_timeout
bool rabit_timeout = false; bool rabit_timeout = false; // NOLINT
// Enable TCP no-delay // Enable TCP no-delay
bool rabit_enable_tcp_no_delay = false; bool rabit_enable_tcp_no_delay = false; // NOLINT
}; };
} // namespace engine } // namespace engine
} // namespace rabit } // namespace rabit

View File

@ -20,74 +20,65 @@ namespace engine {
class AllreduceMock : public AllreduceRobust { class AllreduceMock : public AllreduceRobust {
public: public:
// constructor // constructor
AllreduceMock(void) { AllreduceMock() {
num_trial = 0; num_trial_ = 0;
force_local = 0; force_local_ = 0;
report_stats = 0; report_stats_ = 0;
tsum_allreduce = 0.0; tsum_allreduce_ = 0.0;
tsum_allgather = 0.0; tsum_allgather_ = 0.0;
} }
// destructor // destructor
virtual ~AllreduceMock(void) {} ~AllreduceMock() override = default;
virtual void SetParam(const char *name, const char *val) { void SetParam(const char *name, const char *val) override {
AllreduceRobust::SetParam(name, val); AllreduceRobust::SetParam(name, val);
// additional parameters // additional parameters
if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val); if (!strcmp(name, "rabit_num_trial")) num_trial_ = atoi(val);
if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial = atoi(val); if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial_ = atoi(val);
if (!strcmp(name, "report_stats")) report_stats = atoi(val); if (!strcmp(name, "report_stats")) report_stats_ = atoi(val);
if (!strcmp(name, "force_local")) force_local = atoi(val); if (!strcmp(name, "force_local")) force_local_ = atoi(val);
if (!strcmp(name, "mock")) { if (!strcmp(name, "mock")) {
MockKey k; MockKey k;
utils::Check(sscanf(val, "%d,%d,%d,%d", utils::Check(sscanf(val, "%d,%d,%d,%d",
&k.rank, &k.version, &k.seqno, &k.ntrial) == 4, &k.rank, &k.version, &k.seqno, &k.ntrial) == 4,
"invalid mock parameter"); "invalid mock parameter");
mock_map[k] = 1; mock_map_[k] = 1;
} }
} }
virtual void Allreduce(void *sendrecvbuf_, void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
size_t type_nbytes, ReduceFunction reducer, PreprocFunction prepare_fun,
size_t count, void *prepare_arg, const char *_file = _FILE,
ReduceFunction reducer, const int _line = _LINE,
PreprocFunction prepare_fun, const char *_caller = _CALLER) override {
void *prepare_arg, this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "AllReduce");
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
double tstart = utils::GetTime(); double tstart = utils::GetTime();
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes, AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
count, reducer, prepare_fun, prepare_arg, count, reducer, prepare_fun, prepare_arg,
_file, _line, _caller); _file, _line, _caller);
tsum_allreduce += utils::GetTime() - tstart; tsum_allreduce_ += utils::GetTime() - tstart;
} }
virtual void Allgather(void *sendrecvbuf, void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin,
size_t total_size, size_t slice_end, size_t size_prev_slice,
size_t slice_begin, const char *_file = _FILE, const int _line = _LINE,
size_t slice_end, const char *_caller = _CALLER) override {
size_t size_prev_slice, this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Allgather");
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather");
double tstart = utils::GetTime(); double tstart = utils::GetTime();
AllreduceRobust::Allgather(sendrecvbuf, total_size, AllreduceRobust::Allgather(sendrecvbuf, total_size,
slice_begin, slice_end, slice_begin, slice_end,
size_prev_slice, _file, _line, _caller); size_prev_slice, _file, _line, _caller);
tsum_allgather += utils::GetTime() - tstart; tsum_allgather_ += utils::GetTime() - tstart;
} }
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root, void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char* _file = _FILE, const char *_file = _FILE, const int _line = _LINE,
const int _line = _LINE, const char *_caller = _CALLER) override {
const char* _caller = _CALLER) { this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Broadcast");
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller); AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller);
} }
virtual int LoadCheckPoint(Serializable *global_model, int LoadCheckPoint(Serializable *global_model,
Serializable *local_model) { Serializable *local_model) override {
tsum_allreduce = 0.0; tsum_allreduce_ = 0.0;
tsum_allgather = 0.0; tsum_allgather_ = 0.0;
time_checkpoint = utils::GetTime(); time_checkpoint_ = utils::GetTime();
if (force_local == 0) { if (force_local_ == 0) {
return AllreduceRobust::LoadCheckPoint(global_model, local_model); return AllreduceRobust::LoadCheckPoint(global_model, local_model);
} else { } else {
DummySerializer dum; DummySerializer dum;
@ -95,56 +86,54 @@ class AllreduceMock : public AllreduceRobust {
return AllreduceRobust::LoadCheckPoint(&dum, &com); return AllreduceRobust::LoadCheckPoint(&dum, &com);
} }
} }
virtual void CheckPoint(const Serializable *global_model, void CheckPoint(const Serializable *global_model,
const Serializable *local_model) { const Serializable *local_model) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint"); this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "CheckPoint");
double tstart = utils::GetTime(); double tstart = utils::GetTime();
double tbet_chkpt = tstart - time_checkpoint; double tbet_chkpt = tstart - time_checkpoint_;
if (force_local == 0) { if (force_local_ == 0) {
AllreduceRobust::CheckPoint(global_model, local_model); AllreduceRobust::CheckPoint(global_model, local_model);
} else { } else {
DummySerializer dum; DummySerializer dum;
ComboSerializer com(global_model, local_model); ComboSerializer com(global_model, local_model);
AllreduceRobust::CheckPoint(&dum, &com); AllreduceRobust::CheckPoint(&dum, &com);
} }
time_checkpoint = utils::GetTime(); time_checkpoint_ = utils::GetTime();
double tcost = utils::GetTime() - tstart; double tcost = utils::GetTime() - tstart;
if (report_stats != 0 && rank == 0) { if (report_stats_ != 0 && rank == 0) {
std::stringstream ss; std::stringstream ss;
ss << "[v" << version_number << "] global_size=" << global_checkpoint.length() ss << "[v" << version_number << "] global_size=" << global_checkpoint_.length()
<< ",local_size=" << (local_chkpt[0].length() + local_chkpt[1].length()) << ",local_size=" << (local_chkpt_[0].length() + local_chkpt_[1].length())
<< ",check_tcost="<< tcost <<" sec" << ",check_tcost="<< tcost <<" sec"
<< ",allreduce_tcost=" << tsum_allreduce << " sec" << ",allreduce_tcost=" << tsum_allreduce_ << " sec"
<< ",allgather_tcost=" << tsum_allgather << " sec" << ",allgather_tcost=" << tsum_allgather_ << " sec"
<< ",between_chpt=" << tbet_chkpt << "sec\n"; << ",between_chpt=" << tbet_chkpt << "sec\n";
this->TrackerPrint(ss.str()); this->TrackerPrint(ss.str());
} }
tsum_allreduce = 0.0; tsum_allreduce_ = 0.0;
tsum_allgather = 0.0; tsum_allgather_ = 0.0;
} }
virtual void LazyCheckPoint(const Serializable *global_model) { void LazyCheckPoint(const Serializable *global_model) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint"); this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "LazyCheckPoint");
AllreduceRobust::LazyCheckPoint(global_model); AllreduceRobust::LazyCheckPoint(global_model);
} }
protected: protected:
// force checkpoint to local // force checkpoint to local
int force_local; int force_local_;
// whether report statistics // whether report statistics
int report_stats; int report_stats_;
// sum of allreduce // sum of allreduce
double tsum_allreduce; double tsum_allreduce_;
// sum of allgather // sum of allgather
double tsum_allgather; double tsum_allgather_;
double time_checkpoint; double time_checkpoint_;
private: private:
struct DummySerializer : public Serializable { struct DummySerializer : public Serializable {
virtual void Load(Stream *fi) { void Load(Stream *fi) override {}
} void Save(Stream *fo) const override {}
virtual void Save(Stream *fo) const {
}
}; };
struct ComboSerializer : public Serializable { struct ComboSerializer : public Serializable {
Serializable *lhs; Serializable *lhs;
@ -155,15 +144,15 @@ class AllreduceMock : public AllreduceRobust {
: lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) { : lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
} }
ComboSerializer(const Serializable *lhs, const Serializable *rhs) ComboSerializer(const Serializable *lhs, const Serializable *rhs)
: lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) { : lhs(nullptr), rhs(nullptr), c_lhs(lhs), c_rhs(rhs) {
} }
virtual void Load(Stream *fi) { void Load(Stream *fi) override {
if (lhs != NULL) lhs->Load(fi); if (lhs != nullptr) lhs->Load(fi);
if (rhs != NULL) rhs->Load(fi); if (rhs != nullptr) rhs->Load(fi);
} }
virtual void Save(Stream *fo) const { void Save(Stream *fo) const override {
if (c_lhs != NULL) c_lhs->Save(fo); if (c_lhs != nullptr) c_lhs->Save(fo);
if (c_rhs != NULL) c_rhs->Save(fo); if (c_rhs != nullptr) c_rhs->Save(fo);
} }
}; };
// key to identify the mock stage // key to identify the mock stage
@ -172,7 +161,7 @@ class AllreduceMock : public AllreduceRobust {
int version; int version;
int seqno; int seqno;
int ntrial; int ntrial;
MockKey(void) {} MockKey() = default;
MockKey(int rank, int version, int seqno, int ntrial) MockKey(int rank, int version, int seqno, int ntrial)
: rank(rank), version(version), seqno(seqno), ntrial(ntrial) {} : rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}
inline bool operator==(const MockKey &b) const { inline bool operator==(const MockKey &b) const {
@ -189,15 +178,15 @@ class AllreduceMock : public AllreduceRobust {
} }
}; };
// number of failure trials // number of failure trials
int num_trial; int num_trial_;
// record all mock actions // record all mock actions
std::map<MockKey, int> mock_map; std::map<MockKey, int> mock_map_;
// used to generate all kinds of exceptions // used to generate all kinds of exceptions
inline void Verify(const MockKey &key, const char *name) { inline void Verify(const MockKey &key, const char *name) {
if (mock_map.count(key) != 0) { if (mock_map_.count(key) != 0) {
num_trial += 1; num_trial_ += 1;
// data processing frameworks run on shared process // data processing frameworks run on shared process
_error("[%d]@@@Hit Mock Error:%s ", rank, name); error_("[%d]@@@Hit Mock Error:%s ", rank, name);
} }
} }
}; };

View File

@ -40,9 +40,9 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
const std::vector<EdgeType> &edge_in, const std::vector<EdgeType> &edge_in,
size_t out_index)) { size_t out_index)) {
RefLinkVector &links = tree_links; RefLinkVector &links = tree_links;
if (links.size() == 0) return kSuccess; if (links.Size() == 0) return kSuccess;
// number of links // number of links
const int nlink = static_cast<int>(links.size()); const int nlink = static_cast<int>(links.Size());
// initialize the pointers // initialize the pointers
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
links[i].ResetSize(); links[i].ResetSize();

File diff suppressed because it is too large Load Diff

View File

@ -22,18 +22,18 @@ namespace engine {
/*! \brief implementation of fault tolerant all reduce engine */ /*! \brief implementation of fault tolerant all reduce engine */
class AllreduceRobust : public AllreduceBase { class AllreduceRobust : public AllreduceBase {
public: public:
AllreduceRobust(void); AllreduceRobust();
virtual ~AllreduceRobust(void) {} ~AllreduceRobust() override = default;
// initialize the manager // initialize the manager
virtual bool Init(int argc, char* argv[]); bool Init(int argc, char* argv[]) override;
/*! \brief shutdown the engine */ /*! \brief shutdown the engine */
virtual bool Shutdown(void); bool Shutdown() override;
/*! /*!
* \brief set parameters to the engine * \brief set parameters to the engine
* \param name parameter name * \param name parameter name
* \param val parameter value * \param val parameter value
*/ */
virtual void SetParam(const char *name, const char *val); void SetParam(const char *name, const char *val) override;
/*! /*!
* \brief perform immutable local bootstrap cache insertion * \brief perform immutable local bootstrap cache insertion
* \param key unique cache key * \param key unique cache key
@ -67,13 +67,10 @@ class AllreduceRobust : public AllreduceBase {
* \param _line caller line number used to generate unique cache key * \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key * \param _caller caller function name used to generate unique cache key
*/ */
virtual void Allgather(void *sendrecvbuf_, size_t total_size, void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
size_t slice_begin, size_t slice_end, size_t size_prev_slice,
size_t slice_end, const char *_file = _FILE, const int _line = _LINE,
size_t size_prev_slice, const char *_caller = _CALLER) override;
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
/*! /*!
* \brief perform in-place allreduce, on sendrecvbuf * \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe * this function is NOT thread-safe
@ -90,15 +87,11 @@ class AllreduceRobust : public AllreduceBase {
* \param _line caller line number used to generate unique cache key * \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key * \param _caller caller function name used to generate unique cache key
*/ */
virtual void Allreduce(void *sendrecvbuf_, void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
size_t type_nbytes, ReduceFunction reducer, PreprocFunction prepare_fun = nullptr,
size_t count, void *prepare_arg = nullptr, const char *_file = _FILE,
ReduceFunction reducer, const int _line = _LINE,
PreprocFunction prepare_fun = NULL, const char *_caller = _CALLER) override;
void *prepare_arg = NULL,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
/*! /*!
* \brief broadcast data from root to all nodes * \brief broadcast data from root to all nodes
* \param sendrecvbuf_ buffer for both sending and receiving data * \param sendrecvbuf_ buffer for both sending and receiving data
@ -108,10 +101,9 @@ class AllreduceRobust : public AllreduceBase {
* \param _line caller line number used to generate unique cache key * \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key * \param _caller caller function name used to generate unique cache key
*/ */
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root, void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char* _file = _FILE, const char *_file = _FILE, const int _line = _LINE,
const int _line = _LINE, const char *_caller = _CALLER) override;
const char* _caller = _CALLER);
/*! /*!
* \brief load latest check point * \brief load latest check point
* \param global_model pointer to the globally shared model/state * \param global_model pointer to the globally shared model/state
@ -134,8 +126,8 @@ class AllreduceRobust : public AllreduceBase {
* *
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
virtual int LoadCheckPoint(Serializable *global_model, int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL); Serializable *local_model = nullptr) override;
/*! /*!
* \brief checkpoint the model, meaning we finished a stage of execution * \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one * every time we call check point, there is a version number which will increase by one
@ -152,9 +144,9 @@ class AllreduceRobust : public AllreduceBase {
* *
* \sa LoadCheckPoint, VersionNumber * \sa LoadCheckPoint, VersionNumber
*/ */
virtual void CheckPoint(const Serializable *global_model, void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) { const Serializable *local_model = nullptr) override {
this->CheckPoint_(global_model, local_model, false); this->CheckPointImpl(global_model, local_model, false);
} }
/*! /*!
* \brief This function can be used to replace CheckPoint for global_model only, * \brief This function can be used to replace CheckPoint for global_model only,
@ -176,18 +168,20 @@ class AllreduceRobust : public AllreduceBase {
* is the same in all nodes * is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber * \sa LoadCheckPoint, CheckPoint, VersionNumber
*/ */
virtual void LazyCheckPoint(const Serializable *global_model) { void LazyCheckPoint(const Serializable *global_model) override {
this->CheckPoint_(global_model, NULL, true); this->CheckPointImpl(global_model, nullptr, true);
} }
/*! /*!
* \brief explicitly re-init everything before calling LoadCheckPoint * \brief explicitly re-init everything before calling LoadCheckPoint
* call this function when IEngine throws an exception, * call this function when IEngine throws an exception,
* this function is only used for test purpose * this function is only used for test purpose
*/ */
virtual void InitAfterException(void) { void InitAfterException() override {
// simple way, shutdown all links // simple way, shutdown all links
for (size_t i = 0; i < all_links.size(); ++i) { for (auto& link : all_links) {
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close(); if (link.sock.BadSocket()) {
link.sock.Close();
}
} }
ReConnectLinks("recover"); ReConnectLinks("recover");
} }
@ -245,69 +239,69 @@ class AllreduceRobust : public AllreduceBase {
// there are nodes request load cache // there are nodes request load cache
static const int kLoadBootstrapCache = 16; static const int kLoadBootstrapCache = 16;
// constructor // constructor
ActionSummary(void) {} ActionSummary() = default;
// constructor of action // constructor of action
explicit ActionSummary(int seqno_flag, int cache_flag = 0, explicit ActionSummary(int seqno_flag, int cache_flag = 0,
u_int32_t minseqno = kSpecialOp, u_int32_t maxseqno = kSpecialOp) { u_int32_t minseqno = kSpecialOp, u_int32_t maxseqno = kSpecialOp) {
seqcode = (minseqno << 5) | seqno_flag; seqcode_ = (minseqno << 5) | seqno_flag;
maxseqcode = (maxseqno << 5) | cache_flag; maxseqcode_ = (maxseqno << 5) | cache_flag;
} }
// minimum number of all operations by default // minimum number of all operations by default
// maximum number of all cache operations otherwise // maximum number of all cache operations otherwise
inline u_int32_t seqno(SeqType t = SeqType::kSeq) const { inline u_int32_t Seqno(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode; int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return code >> 5; return code >> 5;
} }
// whether the operation set contains a load_check // whether the operation set contains a load_check
inline bool load_check(SeqType t = SeqType::kSeq) const { inline bool LoadCheck(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode; int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return (code & kLoadCheck) != 0; return (code & kLoadCheck) != 0;
} }
// whether the operation set contains a load_cache // whether the operation set contains a load_cache
inline bool load_cache(SeqType t = SeqType::kSeq) const { inline bool LoadCache(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode; int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return (code & kLoadBootstrapCache) != 0; return (code & kLoadBootstrapCache) != 0;
} }
// whether the operation set contains a check point // whether the operation set contains a check point
inline bool check_point(SeqType t = SeqType::kSeq) const { inline bool CheckPoint(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode; int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return (code & kCheckPoint) != 0; return (code & kCheckPoint) != 0;
} }
// whether the operation set contains a check ack // whether the operation set contains a check ack
inline bool check_ack(SeqType t = SeqType::kSeq) const { inline bool CheckAck(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode; int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return (code & kCheckAck) != 0; return (code & kCheckAck) != 0;
} }
// whether the operation set contains different sequence number // whether the operation set contains different sequence number
inline bool diff_seq() const { inline bool DiffSeq() const {
return (seqcode & kDiffSeq) != 0; return (seqcode_ & kDiffSeq) != 0;
} }
// returns the operation flag of the result // returns the operation flag of the result
inline int flag(SeqType t = SeqType::kSeq) const { inline int Flag(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode; int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return code & 31; return code & 31;
} }
// print flags in user friendly way // print flags in user friendly way
inline void print_flags(int rank, std::string prefix ) { inline void PrintFlags(int rank, std::string prefix ) {
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n", utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n", rank,
rank, prefix.c_str(), prefix.c_str(), Seqno(), CheckPoint(), CheckAck(),
seqno(), check_point(), check_ack(), load_cache(), LoadCache(), DiffSeq(), Seqno(SeqType::kCache),
diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache)); LoadCache(SeqType::kCache));
} }
// reducer for Allreduce, get the result ActionSummary from all nodes // reducer for Allreduce, get the result ActionSummary from all nodes
inline static void Reducer(const void *src_, void *dst_, inline static void Reducer(const void *src_, void *dst_,
int len, const MPI::Datatype &dtype) { int len, const MPI::Datatype &dtype) {
const ActionSummary *src = (const ActionSummary*)src_; const ActionSummary *src = static_cast<const ActionSummary*>(src_);
ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_); ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
u_int32_t min_seqno = std::min(src[i].seqno(), dst[i].seqno()); u_int32_t min_seqno = std::min(src[i].Seqno(), dst[i].Seqno());
u_int32_t max_seqno = std::max(src[i].seqno(SeqType::kCache), u_int32_t max_seqno = std::max(src[i].Seqno(SeqType::kCache),
dst[i].seqno(SeqType::kCache)); dst[i].Seqno(SeqType::kCache));
int action_flag = src[i].flag() | dst[i].flag(); int action_flag = src[i].Flag() | dst[i].Flag();
// if any node is not requester set to 0 otherwise 1 // if any node is not requester set to 0 otherwise 1
int role_flag = src[i].flag(SeqType::kCache) & dst[i].flag(SeqType::kCache); int role_flag = src[i].Flag(SeqType::kCache) & dst[i].Flag(SeqType::kCache);
// if seqno is different in src and destination // if seqno is different in src and destination
int seq_diff_flag = src[i].seqno() != dst[i].seqno() ? kDiffSeq : 0; int seq_diff_flag = src[i].Seqno() != dst[i].Seqno() ? kDiffSeq : 0;
// apply or to both seq diff flag as well as cache seq diff flag // apply or to both seq diff flag as well as cache seq diff flag
dst[i] = ActionSummary(action_flag | seq_diff_flag, dst[i] = ActionSummary(action_flag | seq_diff_flag,
role_flag, min_seqno, max_seqno); role_flag, min_seqno, max_seqno);
@ -316,19 +310,19 @@ class AllreduceRobust : public AllreduceBase {
private: private:
// internel sequence code min of rabit seqno // internel sequence code min of rabit seqno
u_int32_t seqcode; u_int32_t seqcode_;
// internal sequence code max of cache seqno // internal sequence code max of cache seqno
u_int32_t maxseqcode; u_int32_t maxseqcode_;
}; };
/*! \brief data structure to remember result of Bcast and Allreduce calls*/ /*! \brief data structure to remember result of Bcast and Allreduce calls*/
class ResultBuffer{ class ResultBuffer{
public: public:
// constructor // constructor
ResultBuffer(void) { ResultBuffer() {
this->Clear(); this->Clear();
} }
// clear the existing record // clear the existing record
inline void Clear(void) { inline void Clear() {
seqno_.clear(); size_.clear(); seqno_.clear(); size_.clear();
rptr_.clear(); rptr_.push_back(0); rptr_.clear(); rptr_.push_back(0);
data_.clear(); data_.clear();
@ -358,12 +352,12 @@ class AllreduceRobust : public AllreduceBase {
inline void* Query(int seqid, size_t *p_size) { inline void* Query(int seqid, size_t *p_size) {
size_t idx = std::lower_bound(seqno_.begin(), size_t idx = std::lower_bound(seqno_.begin(),
seqno_.end(), seqid) - seqno_.begin(); seqno_.end(), seqid) - seqno_.begin();
if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL; if (idx == seqno_.size() || seqno_[idx] != seqid) return nullptr;
*p_size = size_[idx]; *p_size = size_[idx];
return BeginPtr(data_) + rptr_[idx]; return BeginPtr(data_) + rptr_[idx];
} }
// drop last stored result // drop last stored result
inline void DropLast(void) { inline void DropLast() {
utils::Assert(seqno_.size() != 0, "there is nothing to be dropped"); utils::Assert(seqno_.size() != 0, "there is nothing to be dropped");
seqno_.pop_back(); seqno_.pop_back();
rptr_.pop_back(); rptr_.pop_back();
@ -371,7 +365,7 @@ class AllreduceRobust : public AllreduceBase {
data_.resize(rptr_.back()); data_.resize(rptr_.back());
} }
// the sequence number of last stored result // the sequence number of last stored result
inline int LastSeqNo(void) const { inline int LastSeqNo() const {
if (seqno_.size() == 0) return -1; if (seqno_.size() == 0) return -1;
return seqno_.back(); return seqno_.back();
} }
@ -407,9 +401,8 @@ class AllreduceRobust : public AllreduceBase {
* *
* \sa CheckPoint, LazyCheckPoint * \sa CheckPoint, LazyCheckPoint
*/ */
void CheckPoint_(const Serializable *global_model, void CheckPointImpl(const Serializable *global_model,
const Serializable *local_model, const Serializable *local_model, bool lazy_checkpt);
bool lazy_checkpt);
/*! /*!
* \brief reset the all the existing links by sending Out-of-Band message marker * \brief reset the all the existing links by sending Out-of-Band message marker
* after this function finishes, all the messages received and sent * after this function finishes, all the messages received and sent
@ -423,7 +416,7 @@ class AllreduceRobust : public AllreduceBase {
* when kSockError is returned, it simply means there are bad sockets in the links, * when kSockError is returned, it simply means there are bad sockets in the links,
* and some link recovery proceduer is needed * and some link recovery proceduer is needed
*/ */
ReturnType TryResetLinks(void); ReturnType TryResetLinks();
/*! /*!
* \brief if err_type indicates an error * \brief if err_type indicates an error
* recover links according to the error type reported * recover links according to the error type reported
@ -619,29 +612,29 @@ o * the input state must exactly one saved state(local state of current node)
size_t out_index)); size_t out_index));
//---- recovery data structure ---- //---- recovery data structure ----
// the round of result buffer, used to mode the result // the round of result buffer, used to mode the result
int result_buffer_round; int result_buffer_round_;
// result buffer of all reduce // result buffer of all reduce
ResultBuffer resbuf; ResultBuffer resbuf_;
// current cached allreduce/braodcast sequence number // current cached allreduce/braodcast sequence number
int cur_cache_seq; int cur_cache_seq_;
// result buffer of cached all reduce // result buffer of cached all reduce
ResultBuffer cachebuf; ResultBuffer cachebuf_;
// key of each cache entry // key of each cache entry
ResultBuffer lookupbuf; ResultBuffer lookupbuf_;
// last check point global model // last check point global model
std::string global_checkpoint; std::string global_checkpoint_;
// lazy checkpoint of global model // lazy checkpoint of global model
const Serializable *global_lazycheck; const Serializable *global_lazycheck_;
// number of replica for local state/model // number of replica for local state/model
int num_local_replica; int num_local_replica_;
// number of default local replica // number of default local replica
int default_local_replica; int default_local_replica_;
// flag to decide whether local model is used, -1: unknown, 0: no, 1:yes // flag to decide whether local model is used, -1: unknown, 0: no, 1:yes
int use_local_model; int use_local_model_;
// number of replica for global state/model // number of replica for global state/model
int num_global_replica; int num_global_replica_;
// number of times recovery happens // number of times recovery happens
int recover_counter; int recover_counter_;
// --- recovery data structure for local checkpoint // --- recovery data structure for local checkpoint
// there is two version of the data structure, // there is two version of the data structure,
// at one time one version is valid and another is used as temp memory // at one time one version is valid and another is used as temp memory
@ -649,21 +642,21 @@ o * the input state must exactly one saved state(local state of current node)
// local model is stored in CSR format(like a sparse matrices) // local model is stored in CSR format(like a sparse matrices)
// local_model[rptr[0]:rptr[1]] stores the model of current node // local_model[rptr[0]:rptr[1]] stores the model of current node
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops // local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops
std::vector<size_t> local_rptr[2]; std::vector<size_t> local_rptr_[2];
// storage for local model replicas // storage for local model replicas
std::string local_chkpt[2]; std::string local_chkpt_[2];
// version of local checkpoint can be 1 or 0 // version of local checkpoint can be 1 or 0
int local_chkpt_version; int local_chkpt_version_;
// if checkpoint were loaded, used to distinguish results boostrap cache from seqno cache // if checkpoint were loaded, used to distinguish results boostrap cache from seqno cache
bool checkpoint_loaded; bool checkpoint_loaded_;
// sidecar executing timeout task // sidecar executing timeout task
std::future<bool> rabit_timeout_task; std::future<bool> rabit_timeout_task_;
// flag to shutdown rabit_timeout_task before timeout // flag to shutdown rabit_timeout_task before timeout
std::atomic<bool> shutdown_timeout{false}; std::atomic<bool> shutdown_timeout_{false};
// error handler // error handler
void (* _error)(const char *fmt, ...) = utils::Error; void (* error_)(const char *fmt, ...) = utils::Error;
// assert handler // assert handler
void (* _assert)(bool exp, const char *fmt, ...) = utils::Assert; void (* assert_)(bool exp, const char *fmt, ...) = utils::Assert;
}; };
} // namespace engine } // namespace engine
} // namespace rabit } // namespace rabit

View File

@ -33,12 +33,12 @@ struct FHelper<op::BitOR, DType> {
}; };
template<typename OP> template<typename OP>
void Allreduce_(void *sendrecvbuf_, void Allreduce(void *sendrecvbuf_,
size_t count, size_t count,
engine::mpi::DataType enum_dtype, engine::mpi::DataType enum_dtype,
void (*prepare_fun)(void *arg), void (*prepare_fun)(void *arg),
void *prepare_arg) { void *prepare_arg) {
using namespace engine::mpi; using namespace engine::mpi; // NOLINT
switch (enum_dtype) { switch (enum_dtype) {
case kChar: case kChar:
rabit::Allreduce<OP> rabit::Allreduce<OP>
@ -89,28 +89,28 @@ void Allreduce(void *sendrecvbuf,
engine::mpi::OpType enum_op, engine::mpi::OpType enum_op,
void (*prepare_fun)(void *arg), void (*prepare_fun)(void *arg),
void *prepare_arg) { void *prepare_arg) {
using namespace engine::mpi; using namespace engine::mpi; // NOLINT
switch (enum_op) { switch (enum_op) {
case kMax: case kMax:
Allreduce_<op::Max> Allreduce<op::Max>
(sendrecvbuf, (sendrecvbuf,
count, enum_dtype, count, enum_dtype,
prepare_fun, prepare_arg); prepare_fun, prepare_arg);
return; return;
case kMin: case kMin:
Allreduce_<op::Min> Allreduce<op::Min>
(sendrecvbuf, (sendrecvbuf,
count, enum_dtype, count, enum_dtype,
prepare_fun, prepare_arg); prepare_fun, prepare_arg);
return; return;
case kSum: case kSum:
Allreduce_<op::Sum> Allreduce<op::Sum>
(sendrecvbuf, (sendrecvbuf,
count, enum_dtype, count, enum_dtype,
prepare_fun, prepare_arg); prepare_fun, prepare_arg);
return; return;
case kBitwiseOR: case kBitwiseOR:
Allreduce_<op::BitOR> Allreduce<op::BitOR>
(sendrecvbuf, (sendrecvbuf,
count, enum_dtype, count, enum_dtype,
prepare_fun, prepare_arg); prepare_fun, prepare_arg);
@ -124,7 +124,7 @@ void Allgather(void *sendrecvbuf_,
size_t size_node_slice, size_t size_node_slice,
size_t size_prev_slice, size_t size_prev_slice,
int enum_dtype) { int enum_dtype) {
using namespace engine::mpi; using namespace engine::mpi; // NOLINT
size_t type_size = 0; size_t type_size = 0;
switch (enum_dtype) { switch (enum_dtype) {
case kChar: case kChar:
@ -184,7 +184,7 @@ struct ReadWrapper : public Serializable {
std::string *p_str; std::string *p_str;
explicit ReadWrapper(std::string *p_str) explicit ReadWrapper(std::string *p_str)
: p_str(p_str) {} : p_str(p_str) {}
virtual void Load(Stream *fi) { void Load(Stream *fi) override {
uint64_t sz; uint64_t sz;
utils::Assert(fi->Read(&sz, sizeof(sz)) != 0, utils::Assert(fi->Read(&sz, sizeof(sz)) != 0,
"Read pickle string"); "Read pickle string");
@ -194,7 +194,7 @@ struct ReadWrapper : public Serializable {
"Read pickle string"); "Read pickle string");
} }
} }
virtual void Save(Stream *fo) const { void Save(Stream *fo) const override {
utils::Error("not implemented"); utils::Error("not implemented");
} }
}; };
@ -206,10 +206,10 @@ struct WriteWrapper : public Serializable {
size_t length) size_t length)
: data(data), length(length) { : data(data), length(length) {
} }
virtual void Load(Stream *fi) { void Load(Stream *fi) override {
utils::Error("not implemented"); utils::Error("not implemented");
} }
virtual void Save(Stream *fo) const { void Save(Stream *fo) const override {
uint64_t sz = static_cast<uint16_t>(length); uint64_t sz = static_cast<uint16_t>(length);
fo->Write(&sz, sizeof(sz)); fo->Write(&sz, sizeof(sz));
fo->Write(data, length * sizeof(char)); fo->Write(data, length * sizeof(char));
@ -298,8 +298,8 @@ RABIT_DLL int RabitLoadCheckPoint(char **out_global_model,
ReadWrapper sl(&local_buffer); ReadWrapper sl(&local_buffer);
int version; int version;
if (out_local_model == NULL) { if (out_local_model == nullptr) {
version = rabit::LoadCheckPoint(&sg, NULL); version = rabit::LoadCheckPoint(&sg, nullptr);
*out_global_model = BeginPtr(global_buffer); *out_global_model = BeginPtr(global_buffer);
*out_global_len = static_cast<rbt_ulong>(global_buffer.length()); *out_global_len = static_cast<rbt_ulong>(global_buffer.length());
} else { } else {
@ -317,8 +317,8 @@ RABIT_DLL void RabitCheckPoint(const char *global_model, rbt_ulong global_len,
using namespace rabit::c_api; // NOLINT(*) using namespace rabit::c_api; // NOLINT(*)
WriteWrapper sg(global_model, global_len); WriteWrapper sg(global_model, global_len);
WriteWrapper sl(local_model, local_len); WriteWrapper sl(local_model, local_len);
if (local_model == NULL) { if (local_model == nullptr) {
rabit::CheckPoint(&sg, NULL); rabit::CheckPoint(&sg, nullptr);
} else { } else {
rabit::CheckPoint(&sg, &sl); rabit::CheckPoint(&sg, &sl);
} }

View File

@ -7,18 +7,19 @@
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/ */
#include <rabit/base.h> #include <rabit/base.h>
#include <dmlc/thread_local.h>
#include <memory> #include <memory>
#include "rabit/internal/engine.h" #include "rabit/internal/engine.h"
#include "allreduce_base.h" #include "allreduce_base.h"
#include "allreduce_robust.h" #include "allreduce_robust.h"
#include "rabit/internal/thread_local.h"
namespace rabit { namespace rabit {
namespace engine { namespace engine {
// singleton sync manager // singleton sync manager
#ifndef RABIT_USE_BASE #ifndef RABIT_USE_BASE
#ifndef RABIT_USE_MOCK #ifndef RABIT_USE_MOCK
typedef AllreduceRobust Manager; using Manager = AllreduceRobust;
#else #else
typedef AllreduceMock Manager; typedef AllreduceMock Manager;
#endif // RABIT_USE_MOCK #endif // RABIT_USE_MOCK
@ -31,13 +32,13 @@ struct ThreadLocalEntry {
/*! \brief stores the current engine */ /*! \brief stores the current engine */
std::unique_ptr<Manager> engine; std::unique_ptr<Manager> engine;
/*! \brief whether init has been called */ /*! \brief whether init has been called */
bool initialized; bool initialized{false};
/*! \brief constructor */ /*! \brief constructor */
ThreadLocalEntry() : initialized(false) {} ThreadLocalEntry() = default;
}; };
// define the threadlocal store. // define the threadlocal store.
typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal; using EngineThreadLocal = dmlc::ThreadLocalStore<ThreadLocalEntry>;
/*! \brief intiialize the synchronization module */ /*! \brief intiialize the synchronization module */
bool Init(int argc, char *argv[]) { bool Init(int argc, char *argv[]) {
@ -95,7 +96,7 @@ void Allgather(void *sendrecvbuf_, size_t total_size,
// perform in-place allreduce, on sendrecvbuf // perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf, void Allreduce_(void *sendrecvbuf, // NOLINT
size_t type_nbytes, size_t type_nbytes,
size_t count, size_t count,
IEngine::ReduceFunction red, IEngine::ReduceFunction red,
@ -111,18 +112,15 @@ void Allreduce_(void *sendrecvbuf,
} }
// code for reduce handle // code for reduce handle
ReduceHandle::ReduceHandle(void) ReduceHandle::ReduceHandle() = default;
: handle_(NULL), redfunc_(NULL), htype_(NULL) { ReduceHandle::~ReduceHandle() = default;
}
ReduceHandle::~ReduceHandle(void) {}
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return static_cast<int>(dtype.type_size); return static_cast<int>(dtype.type_size);
} }
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) { void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice"); utils::Assert(redfunc_ == nullptr, "cannot initialize reduce handle twice");
redfunc_ = redfunc; redfunc_ = redfunc;
} }
@ -133,7 +131,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf,
const char* _file, const char* _file,
const int _line, const int _line,
const char* _caller) { const char* _caller) {
utils::Assert(redfunc_ != NULL, "must intialize handle to call AllReduce"); utils::Assert(redfunc_ != nullptr, "must intialize handle to call AllReduce");
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
redfunc_, prepare_fun, prepare_arg, redfunc_, prepare_fun, prepare_arg,
_file, _line, _caller); _file, _line, _caller);

View File

@ -16,78 +16,68 @@ namespace engine {
/*! \brief EmptyEngine */ /*! \brief EmptyEngine */
class EmptyEngine : public IEngine { class EmptyEngine : public IEngine {
public: public:
EmptyEngine(void) { EmptyEngine() {
version_number = 0; version_number_ = 0;
} }
virtual void Allgather(void *sendrecvbuf_, void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
size_t total_size, size_t slice_end, size_t size_prev_slice, const char *_file,
size_t slice_begin, const int _line, const char *_caller) override {
size_t slice_end,
size_t size_prev_slice,
const char* _file,
const int _line,
const char* _caller) {
utils::Error("EmptyEngine:: Allgather is not supported"); utils::Error("EmptyEngine:: Allgather is not supported");
} }
virtual int GetRingPrevRank(void) const { int GetRingPrevRank() const override {
utils::Error("EmptyEngine:: GetRingPrevRank is not supported"); utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
return -1; return -1;
} }
virtual void Allreduce(void *sendrecvbuf_, void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
size_t type_nbytes, ReduceFunction reducer, PreprocFunction prepare_fun,
size_t count, void *prepare_arg, const char *_file, const int _line,
ReduceFunction reducer, const char *_caller) override {
PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
utils::Error("EmptyEngine:: Allreduce is not supported,"\ utils::Error("EmptyEngine:: Allreduce is not supported,"\
"use Allreduce_ instead"); "use Allreduce_ instead");
} }
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root, void Broadcast(void *sendrecvbuf_, size_t size, int root,
const char* _file, const int _line, const char* _caller) { const char* _file, const int _line, const char* _caller) override {
} }
virtual void InitAfterException(void) { void InitAfterException() override {
utils::Error("EmptyEngine is not fault tolerant"); utils::Error("EmptyEngine is not fault tolerant");
} }
virtual int LoadCheckPoint(Serializable *global_model, int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL) { Serializable *local_model = nullptr) override {
return 0; return 0;
} }
virtual void CheckPoint(const Serializable *global_model, void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) { const Serializable *local_model = nullptr) override {
version_number += 1; version_number_ += 1;
} }
virtual void LazyCheckPoint(const Serializable *global_model) { void LazyCheckPoint(const Serializable *global_model) override {
version_number += 1; version_number_ += 1;
} }
virtual int VersionNumber(void) const { int VersionNumber() const override {
return version_number; return version_number_;
} }
/*! \brief get rank of current node */ /*! \brief get rank of current node */
virtual int GetRank(void) const { int GetRank() const override {
return 0; return 0;
} }
/*! \brief get total number of */ /*! \brief get total number of */
virtual int GetWorldSize(void) const { int GetWorldSize() const override {
return 1; return 1;
} }
/*! \brief whether it is distributed */ /*! \brief whether it is distributed */
virtual bool IsDistributed(void) const { bool IsDistributed() const override {
return false; return false;
} }
/*! \brief get the host name of current node */ /*! \brief get the host name of current node */
virtual std::string GetHost(void) const { std::string GetHost() const override {
return std::string(""); return std::string("");
} }
virtual void TrackerPrint(const std::string &msg) { void TrackerPrint(const std::string &msg) override {
// simply print information into the tracker // simply print information into the tracker
utils::Printf("%s", msg.c_str()); utils::Printf("%s", msg.c_str());
} }
private: private:
int version_number; int version_number_;
}; };
// singleton sync manager // singleton sync manager
@ -98,12 +88,12 @@ bool Init(int argc, char *argv[]) {
return true; return true;
} }
/*! \brief finalize syncrhonization module */ /*! \brief finalize syncrhonization module */
bool Finalize(void) { bool Finalize() {
return true; return true;
} }
/*! \brief singleton method to get engine */ /*! \brief singleton method to get engine */
IEngine *GetEngine(void) { IEngine *GetEngine() {
return &manager; return &manager;
} }
// perform in-place allreduce, on sendrecvbuf // perform in-place allreduce, on sendrecvbuf
@ -118,13 +108,12 @@ void Allreduce_(void *sendrecvbuf,
const char* _file, const char* _file,
const int _line, const int _line,
const char* _caller) { const char* _caller) {
if (prepare_fun != NULL) prepare_fun(prepare_arg); if (prepare_fun != nullptr) prepare_fun(prepare_arg);
} }
// code for reduce handle // code for reduce handle
ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) { ReduceHandle::ReduceHandle() = default;
} ReduceHandle::~ReduceHandle() = default;
ReduceHandle::~ReduceHandle(void) {}
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return 0; return 0;
@ -137,7 +126,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf,
const char* _file, const char* _file,
const int _line, const int _line,
const char* _caller) { const char* _caller) {
if (prepare_fun != NULL) prepare_fun(prepare_arg); if (prepare_fun != nullptr) prepare_fun(prepare_arg);
} }
} // namespace engine } // namespace engine
} // namespace rabit } // namespace rabit

View File

@ -1,7 +1,7 @@
/*! /*!
* Copyright (c) 2014 by Contributors * Copyright (c) 2014 by Contributors
* \file engine_mock.cc * \file engine_mock.cc
* \brief this is an engine implementation that will * \brief this is an engine implementation that will
* insert failures in certain call point, to test if the engine is robust to failure * insert failures in certain call point, to test if the engine is robust to failure
* \author Tianqi Chen * \author Tianqi Chen
*/ */
@ -12,4 +12,3 @@
#include <rabit/base.h> #include <rabit/base.h>
#include "allreduce_mock.h" #include "allreduce_mock.h"
#include "engine.cc" #include "engine.cc"