From b0001a6e29a8c4423953eb00b33293a9868f6460 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 8 Sep 2020 12:13:58 +0800 Subject: [PATCH] Correct style warnings from clang-tidy for rabit. (#6095) --- rabit/include/rabit/internal/engine.h | 56 +-- rabit/include/rabit/internal/io.h | 30 +- rabit/include/rabit/internal/rabit-inl.h | 68 +-- rabit/include/rabit/internal/socket.h | 60 +-- rabit/include/rabit/internal/thread_local.h | 87 ---- rabit/include/rabit/internal/timer.h | 4 +- rabit/include/rabit/internal/utils.h | 22 +- rabit/include/rabit/rabit.h | 22 +- rabit/include/rabit/serializable.h | 4 +- rabit/src/allreduce_base.cc | 122 ++--- rabit/src/allreduce_base.h | 142 +++--- rabit/src/allreduce_mock.h | 153 +++--- rabit/src/allreduce_robust-inl.h | 4 +- rabit/src/allreduce_robust.cc | 509 ++++++++++---------- rabit/src/allreduce_robust.h | 183 ++++--- rabit/src/c_api.cc | 32 +- rabit/src/engine.cc | 24 +- rabit/src/engine_empty.cc | 79 ++- rabit/src/engine_mock.cc | 3 +- 19 files changed, 736 insertions(+), 868 deletions(-) delete mode 100644 rabit/include/rabit/internal/thread_local.h diff --git a/rabit/include/rabit/internal/engine.h b/rabit/include/rabit/internal/engine.h index 0db10e7f0..8cb03f35f 100644 --- a/rabit/include/rabit/internal/engine.h +++ b/rabit/include/rabit/internal/engine.h @@ -19,7 +19,7 @@ #define _CALLER "N/A" #endif // (defined(__GNUC__) && !defined(__clang__)) -namespace MPI { +namespace MPI { // NOLINT /*! \brief MPI data type just to be compatible with MPI reduce function*/ class Datatype; } @@ -36,7 +36,7 @@ class IEngine { * used to prepare the data used by AllReduce * \param arg additional possible argument used to invoke the preprocessor */ - typedef void (PreprocFunction) (void *arg); + typedef void (PreprocFunction) (void *arg); // NOLINT /*! 
* \brief reduce function, the same form of MPI reduce function is used, * to be compatible with MPI interface @@ -48,11 +48,11 @@ class IEngine { * the definition of the reduce function should be type aware * \param dtype the data type object, to be compatible with MPI reduce */ - typedef void (ReduceFunction) (const void *src, + typedef void (ReduceFunction) (const void *src, // NOLINT void *dst, int count, const MPI::Datatype &dtype); /*! \brief virtual destructor */ - virtual ~IEngine() {} + ~IEngine() = default; /*! * \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf, * the data provided by current node k is [slice_begin, slice_end), @@ -68,7 +68,7 @@ class IEngine { * \param _file caller file name used to generate unique cache key * \param _line caller line number used to generate unique cache key * \param _caller caller function name used to generate unique cache key - */ + */ virtual void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, @@ -96,8 +96,8 @@ class IEngine { size_t type_nbytes, size_t count, ReduceFunction reducer, - PreprocFunction prepare_fun = NULL, - void *prepare_arg = NULL, + PreprocFunction prepare_fun = nullptr, + void *prepare_arg = nullptr, const char* _file = _FILE, const int _line = _LINE, const char* _caller = _CALLER) = 0; @@ -119,7 +119,7 @@ class IEngine { * call this function when IEngine throws an exception, * this function should only be used for test purposes */ - virtual void InitAfterException(void) = 0; + virtual void InitAfterException() = 0; /*! * \brief loads the latest check point * \param global_model pointer to the globally shared model/state @@ -143,7 +143,7 @@ class IEngine { * \sa CheckPoint, VersionNumber */ virtual int LoadCheckPoint(Serializable *global_model, - Serializable *local_model = NULL) = 0; + Serializable *local_model = nullptr) = 0; /*! 
* \brief checkpoints the model, meaning a stage of execution was finished * every time we call check point, a version number increases by ones @@ -161,7 +161,7 @@ class IEngine { * \sa LoadCheckPoint, VersionNumber */ virtual void CheckPoint(const Serializable *global_model, - const Serializable *local_model = NULL) = 0; + const Serializable *local_model = nullptr) = 0; /*! * \brief This function can be used to replace CheckPoint for global_model only, * when certain condition is met (see detailed explanation). @@ -188,17 +188,17 @@ class IEngine { * which means how many calls to CheckPoint we made so far * \sa LoadCheckPoint, CheckPoint */ - virtual int VersionNumber(void) const = 0; + virtual int VersionNumber() const = 0; /*! \brief gets rank of previous node in ring topology */ - virtual int GetRingPrevRank(void) const = 0; + virtual int GetRingPrevRank() const = 0; /*! \brief gets rank of current node */ - virtual int GetRank(void) const = 0; + virtual int GetRank() const = 0; /*! \brief gets total number of nodes */ - virtual int GetWorldSize(void) const = 0; + virtual int GetWorldSize() const = 0; /*! \brief whether we run in distribted mode */ - virtual bool IsDistributed(void) const = 0; + virtual bool IsDistributed() const = 0; /*! \brief gets the host name of the current node */ - virtual std::string GetHost(void) const = 0; + virtual std::string GetHost() const = 0; /*! * \brief prints the msg in the tracker, * this function can be used to communicate progress information to @@ -211,9 +211,9 @@ class IEngine { /*! \brief initializes the engine module */ bool Init(int argc, char *argv[]); /*! \brief finalizes the engine module */ -bool Finalize(void); +bool Finalize(); /*! \brief singleton method to get engine */ -IEngine *GetEngine(void); +IEngine *GetEngine(); /*! 
\brief namespace that contains stubs to be compatible with MPI */ namespace mpi { @@ -280,14 +280,14 @@ void Allgather(void* sendrecvbuf, * \param _line caller line number used to generate unique cache key * \param _caller caller function name used to generate unique cache key */ -void Allreduce_(void *sendrecvbuf, +void Allreduce_(void *sendrecvbuf, // NOLINT size_t type_nbytes, size_t count, IEngine::ReduceFunction red, mpi::DataType dtype, mpi::OpType op, - IEngine::PreprocFunction prepare_fun = NULL, - void *prepare_arg = NULL, + IEngine::PreprocFunction prepare_fun = nullptr, + void *prepare_arg = nullptr, const char* _file = _FILE, const int _line = _LINE, const char* _caller = _CALLER); @@ -298,9 +298,9 @@ void Allreduce_(void *sendrecvbuf, class ReduceHandle { public: // constructor - ReduceHandle(void); + ReduceHandle(); // destructor - ~ReduceHandle(void); + ~ReduceHandle(); /*! * \brief initialize the reduce function, * with the type the reduce function needs to deal with @@ -323,8 +323,8 @@ class ReduceHandle { void Allreduce(void *sendrecvbuf, size_t type_nbytes, size_t count, - IEngine::PreprocFunction prepare_fun = NULL, - void *prepare_arg = NULL, + IEngine::PreprocFunction prepare_fun = nullptr, + void *prepare_arg = nullptr, const char* _file = _FILE, const int _line = _LINE, const char* _caller = _CALLER); @@ -333,11 +333,11 @@ class ReduceHandle { protected: // handle function field - void *handle_; + void *handle_ {nullptr}; // reduce function of the reducer - IEngine::ReduceFunction *redfunc_; + IEngine::ReduceFunction *redfunc_{nullptr}; // handle to the type field - void *htype_; + void *htype_{nullptr}; // the created type in 4 bytes size_t created_type_nbytes_; }; diff --git a/rabit/include/rabit/internal/io.h b/rabit/include/rabit/internal/io.h index f7e255b45..978eebd8a 100644 --- a/rabit/include/rabit/internal/io.h +++ b/rabit/include/rabit/internal/io.h @@ -19,12 +19,12 @@ namespace rabit { namespace utils { /*! 
\brief re-use definition of dmlc::SeekStream */ -typedef dmlc::SeekStream SeekStream; +using SeekStream = dmlc::SeekStream; /*! \brief fixed size memory buffer */ struct MemoryFixSizeBuffer : public SeekStream { public: // similar to SEEK_END in libc - static size_t constexpr SeekEnd = std::numeric_limits::max(); + static size_t constexpr kSeekEnd = std::numeric_limits::max(); public: MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size) @@ -32,31 +32,31 @@ struct MemoryFixSizeBuffer : public SeekStream { buffer_size_(buffer_size) { curr_ptr_ = 0; } - virtual ~MemoryFixSizeBuffer(void) {} - virtual size_t Read(void *ptr, size_t size) { + ~MemoryFixSizeBuffer() override = default; + size_t Read(void *ptr, size_t size) override { size_t nread = std::min(buffer_size_ - curr_ptr_, size); if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread); curr_ptr_ += nread; return nread; } - virtual void Write(const void *ptr, size_t size) { + void Write(const void *ptr, size_t size) override { if (size == 0) return; utils::Assert(curr_ptr_ + size <= buffer_size_, "write position exceed fixed buffer size"); std::memcpy(p_buffer_ + curr_ptr_, ptr, size); curr_ptr_ += size; } - virtual void Seek(size_t pos) { - if (pos == SeekEnd) { + void Seek(size_t pos) override { + if (pos == kSeekEnd) { curr_ptr_ = buffer_size_; } else { curr_ptr_ = static_cast(pos); } } - virtual size_t Tell(void) { + size_t Tell() override { return curr_ptr_; } - virtual bool AtEnd(void) const { + virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; } @@ -76,8 +76,8 @@ struct MemoryBufferStream : public SeekStream { : p_buffer_(p_buffer) { curr_ptr_ = 0; } - virtual ~MemoryBufferStream(void) {} - virtual size_t Read(void *ptr, size_t size) { + ~MemoryBufferStream() override = default; + size_t Read(void *ptr, size_t size) override { utils::Assert(curr_ptr_ <= p_buffer_->length(), "read can not have position excceed buffer length"); size_t nread = std::min(p_buffer_->length() - curr_ptr_, 
size); @@ -85,7 +85,7 @@ struct MemoryBufferStream : public SeekStream { curr_ptr_ += nread; return nread; } - virtual void Write(const void *ptr, size_t size) { + void Write(const void *ptr, size_t size) override { if (size == 0) return; if (curr_ptr_ + size > p_buffer_->length()) { p_buffer_->resize(curr_ptr_+size); @@ -93,13 +93,13 @@ struct MemoryBufferStream : public SeekStream { std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size); curr_ptr_ += size; } - virtual void Seek(size_t pos) { + void Seek(size_t pos) override { curr_ptr_ = static_cast(pos); } - virtual size_t Tell(void) { + size_t Tell() override { return curr_ptr_; } - virtual bool AtEnd(void) const { + virtual bool AtEnd() const { return curr_ptr_ == p_buffer_->length(); } diff --git a/rabit/include/rabit/internal/rabit-inl.h b/rabit/include/rabit/internal/rabit-inl.h index 8ae604c4c..fa6902f84 100644 --- a/rabit/include/rabit/internal/rabit-inl.h +++ b/rabit/include/rabit/internal/rabit-inl.h @@ -19,45 +19,45 @@ namespace engine { namespace mpi { // template function to translate type to enum indicator template -inline DataType GetType(void); +inline DataType GetType(); template<> -inline DataType GetType(void) { +inline DataType GetType() { return kChar; } template<> -inline DataType GetType(void) { +inline DataType GetType() { return kUChar; } template<> -inline DataType GetType(void) { +inline DataType GetType() { return kInt; } template<> -inline DataType GetType(void) { // NOLINT(*) +inline DataType GetType() { // NOLINT(*) return kUInt; } template<> -inline DataType GetType(void) { // NOLINT(*) +inline DataType GetType() { // NOLINT(*) return kLong; } template<> -inline DataType GetType(void) { // NOLINT(*) +inline DataType GetType() { // NOLINT(*) return kULong; } template<> -inline DataType GetType(void) { +inline DataType GetType() { return kFloat; } template<> -inline DataType GetType(void) { +inline DataType GetType() { return kDouble; } template<> -inline DataType GetType(void) { // 
NOLINT(*) +inline DataType GetType() { // NOLINT(*) return kLongLong; } template<> -inline DataType GetType(void) { // NOLINT(*) +inline DataType GetType() { // NOLINT(*) return kULongLong; } } // namespace mpi @@ -94,7 +94,7 @@ struct BitOR { }; template inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) { - const DType* src = (const DType*)src_; + const DType* src = static_cast(src_); DType* dst = (DType*)dst_; // NOLINT(*) for (int i = 0; i < len; i++) { OP::Reduce(dst[i], src[i]); @@ -107,27 +107,27 @@ inline bool Init(int argc, char *argv[]) { return engine::Init(argc, argv); } // finalize the rabit engine -inline bool Finalize(void) { +inline bool Finalize() { return engine::Finalize(); } // get the rank of the previous worker in ring topology -inline int GetRingPrevRank(void) { +inline int GetRingPrevRank() { return engine::GetEngine()->GetRingPrevRank(); } // get the rank of current process -inline int GetRank(void) { +inline int GetRank() { return engine::GetEngine()->GetRank(); } // the the size of the world -inline int GetWorldSize(void) { +inline int GetWorldSize() { return engine::GetEngine()->GetWorldSize(); } // whether rabit is distributed -inline bool IsDistributed(void) { +inline bool IsDistributed() { return engine::GetEngine()->IsDistributed(); } // get the name of current processor -inline std::string GetProcessorName(void) { +inline std::string GetProcessorName() { return engine::GetEngine()->GetHost(); } // broadcast data to all other nodes from root @@ -183,7 +183,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, // C++11 support for lambda prepare function #if DMLC_USE_CXX11 -inline void InvokeLambda_(void *fun) { +inline void InvokeLambda(void *fun) { (*static_cast*>(fun))(); } template @@ -193,7 +193,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, const int _line, const char* _caller) { engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer, - engine::mpi::GetType(), 
OP::kType, InvokeLambda_, &prepare_fun, + engine::mpi::GetType(), OP::kType, InvokeLambda, &prepare_fun, _file, _line, _caller); } @@ -245,7 +245,7 @@ inline void LazyCheckPoint(const Serializable *global_model) { engine::GetEngine()->LazyCheckPoint(global_model); } // return the version number of currently stored model -inline int VersionNumber(void) { +inline int VersionNumber() { return engine::GetEngine()->VersionNumber(); } // --------------------------------- @@ -253,7 +253,7 @@ inline int VersionNumber(void) { // --------------------------------- // function to perform reduction for Reducer template -inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) { +inline void ReducerSafeImpl(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) { const size_t kUnit = sizeof(DType); const char *psrc = reinterpret_cast(src_); char *pdst = reinterpret_cast(dst_); @@ -269,7 +269,7 @@ inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Data } // function to perform reduction for Reducer template // NOLINT(*) -inline void ReducerAlign_(const void *src_, void *dst_, +inline void ReducerAlignImpl(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) { const DType *psrc = reinterpret_cast(src_); DType *pdst = reinterpret_cast(dst_); @@ -278,12 +278,12 @@ inline void ReducerAlign_(const void *src_, void *dst_, } } template // NOLINT(*) -inline Reducer::Reducer(void) { +inline Reducer::Reducer() { // it is safe to directly use handle for aligned data types if (sizeof(DType) == 8 || sizeof(DType) == 4 || sizeof(DType) == 1) { - this->handle_.Init(ReducerAlign_, sizeof(DType)); + this->handle_.Init(ReducerAlignImpl, sizeof(DType)); } else { - this->handle_.Init(ReducerSafe_, sizeof(DType)); + this->handle_.Init(ReducerSafeImpl, sizeof(DType)); } } template // NOLINT(*) @@ -298,8 +298,8 @@ inline void Reducer::Allreduce(DType *sendrecvbuf, size_t count, } // function to perform reduction 
for SerializeReducer template -inline void SerializeReducerFunc_(const void *src_, void *dst_, - int len_, const MPI::Datatype &dtype) { +inline void SerializeReducerFuncImpl(const void *src_, void *dst_, + int len_, const MPI::Datatype &dtype) { int nbytes = engine::ReduceHandle::TypeSize(dtype); // temp space for (int i = 0; i < len_; ++i) { @@ -315,8 +315,8 @@ inline void SerializeReducerFunc_(const void *src_, void *dst_, } } template -inline SerializeReducer::SerializeReducer(void) { - handle_.Init(SerializeReducerFunc_, sizeof(DType)); +inline SerializeReducer::SerializeReducer() { + handle_.Init(SerializeReducerFuncImpl, sizeof(DType)); } // closure to call Allreduce template @@ -327,8 +327,8 @@ struct SerializeReduceClosure { void *prepare_arg; std::string *p_buffer; // invoke the closure - inline void Run(void) { - if (prepare_fun != NULL) prepare_fun(prepare_arg); + inline void Run() { + if (prepare_fun != nullptr) prepare_fun(prepare_arg); for (size_t i = 0; i < count; ++i) { utils::MemoryFixSizeBuffer fs(BeginPtr(*p_buffer) + i * max_nbyte, max_nbyte); sendrecvobj[i].Save(fs); @@ -368,7 +368,7 @@ inline void Reducer::Allreduce(DType *sendrecvbuf, size_t count, const char* _file, const int _line, const char* _caller) { - this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun, + this->Allreduce(sendrecvbuf, count, InvokeLambda, &prepare_fun, _file, _line, _caller); } template @@ -378,7 +378,7 @@ inline void SerializeReducer::Allreduce(DType *sendrecvobj, const char* _file, const int _line, const char* _caller) { - this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda_, &prepare_fun, + this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda, &prepare_fun, _file, _line, _caller); } #endif // DMLC_USE_CXX11 diff --git a/rabit/include/rabit/internal/socket.h b/rabit/include/rabit/internal/socket.h index a80c4bc01..da48d6262 100644 --- a/rabit/include/rabit/internal/socket.h +++ b/rabit/include/rabit/internal/socket.h @@ -15,7 +15,7 @@ 
#else #include #include -#include +#include #include #include #include @@ -39,9 +39,9 @@ static inline int poll(struct pollfd *pfd, int nfds, int timeout) { return WSAPoll ( pfd, nfds, timeout ); } #else #include -typedef int SOCKET; -typedef size_t sock_size_t; -const int INVALID_SOCKET = -1; +using SOCKET = int; +using sock_size_t = size_t; // NOLINT +const int kInvalidSocket = -1; #endif // defined(_WIN32) namespace rabit { @@ -50,11 +50,11 @@ namespace utils { struct SockAddr { sockaddr_in addr; // constructor - SockAddr(void) {} + SockAddr() = default; SockAddr(const char *url, int port) { this->Set(url, port); } - inline static std::string GetHostName(void) { + inline static std::string GetHostName() { std::string buf; buf.resize(256); utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name"); return std::string(buf.c_str()); @@ -69,20 +69,20 @@ struct SockAddr { memset(&hints, 0, sizeof(hints)); hints.ai_family = AF_INET; hints.ai_protocol = SOCK_STREAM; - addrinfo *res = NULL; - int sig = getaddrinfo(host, NULL, &hints, &res); - Check(sig == 0 && res != NULL, "cannot obtain address of %s", host); + addrinfo *res = nullptr; + int sig = getaddrinfo(host, nullptr, &hints, &res); + Check(sig == 0 && res != nullptr, "cannot obtain address of %s", host); Check(res->ai_family == AF_INET, "Does not support IPv6"); memcpy(&addr, res->ai_addr, res->ai_addrlen); addr.sin_port = htons(port); freeaddrinfo(res); } /*! \brief return port of the address*/ - inline int port(void) const { + inline int Port() const { return ntohs(addr.sin_port); } /*! 
\return a string representation of the address */ - inline std::string AddrStr(void) const { + inline std::string AddrStr() const { std::string buf; buf.resize(256); #ifdef _WIN32 const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, @@ -91,7 +91,7 @@ struct SockAddr { const char *s = inet_ntop(AF_INET, &addr.sin_addr, &buf[0], buf.length()); #endif // _WIN32 - Assert(s != NULL, "cannot decode address"); + Assert(s != nullptr, "cannot decode address"); return std::string(s); } }; @@ -104,13 +104,13 @@ class Socket { /*! \brief the file descriptor of socket */ SOCKET sockfd; // default conversion to int - inline operator SOCKET() const { + operator SOCKET() const { // NOLINT return sockfd; } /*! * \return last error of socket operation */ - inline static int GetLastError(void) { + inline static int GetLastError() { #ifdef _WIN32 return WSAGetLastError(); #else @@ -118,7 +118,7 @@ class Socket { #endif // _WIN32 } /*! \return whether last error was would block */ - inline static bool LastErrorWouldBlock(void) { + inline static bool LastErrorWouldBlock() { int errsv = GetLastError(); #ifdef _WIN32 return errsv == WSAEWOULDBLOCK; @@ -130,7 +130,7 @@ class Socket { * \brief start up the socket module * call this before using the sockets */ - inline static void Startup(void) { + inline static void Startup() { #ifdef _WIN32 WSADATA wsa_data; if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) { @@ -145,7 +145,7 @@ class Socket { /*! * \brief shutdown the socket module after use, all sockets need to be closed */ - inline static void Finalize(void) { + inline static void Finalize() { #ifdef _WIN32 WSACleanup(); #endif // _WIN32 @@ -214,7 +214,7 @@ class Socket { return -1; } /*! \brief get last error code if any */ - inline int GetSockError(void) const { + inline int GetSockError() const { int error = 0; socklen_t len = sizeof(error); if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, @@ -224,25 +224,25 @@ class Socket { return error; } /*! 
\brief check if anything bad happens */ - inline bool BadSocket(void) const { + inline bool BadSocket() const { if (IsClosed()) return true; int err = GetSockError(); if (err == EBADF || err == EINTR) return true; return false; } /*! \brief check if socket is already closed */ - inline bool IsClosed(void) const { - return sockfd == INVALID_SOCKET; + inline bool IsClosed() const { + return sockfd == kInvalidSocket; } /*! \brief close the socket */ - inline void Close(void) { - if (sockfd != INVALID_SOCKET) { + inline void Close() { + if (sockfd != kInvalidSocket) { #ifdef _WIN32 closesocket(sockfd); #else close(sockfd); #endif - sockfd = INVALID_SOCKET; + sockfd = kInvalidSocket; } else { Error("Socket::Close double close the socket or close without create"); } @@ -268,7 +268,7 @@ class Socket { class TCPSocket : public Socket{ public: // constructor - TCPSocket(void) : Socket(INVALID_SOCKET) { + TCPSocket() : Socket(kInvalidSocket) { } explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) { } @@ -297,7 +297,7 @@ class TCPSocket : public Socket{ */ inline void Create(int af = PF_INET) { sockfd = socket(PF_INET, SOCK_STREAM, 0); - if (sockfd == INVALID_SOCKET) { + if (sockfd == kInvalidSocket) { Socket::Error("Create"); } } @@ -309,9 +309,9 @@ class TCPSocket : public Socket{ listen(sockfd, backlog); } /*! 
\brief get a new connection */ - TCPSocket Accept(void) { - SOCKET newfd = accept(sockfd, NULL, NULL); - if (newfd == INVALID_SOCKET) { + TCPSocket Accept() { + SOCKET newfd = accept(sockfd, nullptr, nullptr); + if (newfd == kInvalidSocket) { Socket::Error("Accept"); } return TCPSocket(newfd); @@ -320,7 +320,7 @@ class TCPSocket : public Socket{ * \brief decide whether the socket is at OOB mark * \return 1 if at mark, 0 if not, -1 if an error occured */ - inline int AtMark(void) const { + inline int AtMark() const { #ifdef _WIN32 unsigned long atmark; // NOLINT(*) if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1; diff --git a/rabit/include/rabit/internal/thread_local.h b/rabit/include/rabit/internal/thread_local.h deleted file mode 100644 index 4eebd6459..000000000 --- a/rabit/include/rabit/internal/thread_local.h +++ /dev/null @@ -1,87 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file thread_local.h - * \brief Common utility for thread local storage. - */ -#ifndef RABIT_INTERNAL_THREAD_LOCAL_H_ -#define RABIT_INTERNAL_THREAD_LOCAL_H_ - -#include "../include/dmlc/base.h" - -#if DMLC_ENABLE_STD_THREAD -#include -#endif // DMLC_ENABLE_STD_THREAD - -#include -#include - -namespace rabit { - -// macro hanlding for threadlocal variables -#ifdef __GNUC__ - #define MX_TREAD_LOCAL __thread -#elif __STDC_VERSION__ >= 201112L - #define MX_TREAD_LOCAL _Thread_local -#elif defined(_MSC_VER) - #define MX_TREAD_LOCAL __declspec(thread) -#endif // __GNUC__ - -#ifndef MX_TREAD_LOCAL -#message("Warning: Threadlocal is not enabled"); -#endif // MX_TREAD_LOCAL - -/*! - * \brief A threadlocal store to store threadlocal variables. - * Will return a thread local singleton of type T - * \tparam T the type we like to store - */ -template -class ThreadLocalStore { - public: - /*! 
\return get a thread local singleton */ - static T* Get() { - static MX_TREAD_LOCAL T* ptr = nullptr; - if (ptr == nullptr) { - ptr = new T(); - Singleton()->RegisterDelete(ptr); - } - return ptr; - } - - private: - /*! \brief constructor */ - ThreadLocalStore() {} - /*! \brief destructor */ - ~ThreadLocalStore() { - for (size_t i = 0; i < data_.size(); ++i) { - delete data_[i]; - } - } - /*! \return singleton of the store */ - static ThreadLocalStore *Singleton() { - static ThreadLocalStore inst; - return &inst; - } - /*! - * \brief register str for internal deletion - * \param str the string pointer - */ - void RegisterDelete(T *str) { -#if DMLC_ENABLE_STD_THREAD - std::unique_lock lock(mutex_); - data_.push_back(str); - lock.unlock(); -#else - data_.push_back(str); -#endif // DMLC_ENABLE_STD_THREAD - } - -#if DMLC_ENABLE_STD_THREAD - /*! \brief internal mutex */ - std::mutex mutex_; -#endif // DMLC_ENABLE_STD_THREAD - /*!\brief internal data */ - std::vector data_; -}; -} // namespace rabit -#endif // RABIT_INTERNAL_THREAD_LOCAL_H_ diff --git a/rabit/include/rabit/internal/timer.h b/rabit/include/rabit/internal/timer.h index 3ce1bf8f2..95a371027 100644 --- a/rabit/include/rabit/internal/timer.h +++ b/rabit/include/rabit/internal/timer.h @@ -6,7 +6,7 @@ */ #ifndef RABIT_INTERNAL_TIMER_H_ #define RABIT_INTERNAL_TIMER_H_ -#include +#include #ifdef __MACH__ #include #include @@ -18,7 +18,7 @@ namespace utils { /*! 
* \brief return time in seconds, not cross platform, avoid to use this in most places */ -inline double GetTime(void) { +inline double GetTime() { #ifdef __MACH__ clock_serv_t cclock; mach_timespec_t mts; diff --git a/rabit/include/rabit/internal/utils.h b/rabit/include/rabit/internal/utils.h index 5a6b43b6c..825d8e666 100644 --- a/rabit/include/rabit/internal/utils.h +++ b/rabit/include/rabit/internal/utils.h @@ -8,7 +8,7 @@ #define RABIT_INTERNAL_UTILS_H_ #include -#include +#include #include #include #include @@ -48,15 +48,7 @@ extern "C" { } #endif // _MSC_VER -#ifdef _MSC_VER -typedef unsigned char uint8_t; -typedef unsigned __int16 uint16_t; -typedef unsigned __int32 uint32_t; -typedef unsigned __int64 uint64_t; -typedef __int64 int64_t; -#else -#include -#endif // _MSC_VER +#include namespace rabit { /*! \brief namespace for helper utils of the project */ @@ -184,7 +176,7 @@ inline void Error(const char *fmt, ...) { /*! \brief replace fopen, report error when the file open fails */ inline std::FILE *FopenCheck(const char *fname, const char *flag) { std::FILE *fp = fopen64(fname, flag); - Check(fp != NULL, "can not open file \"%s\"\n", fname); + Check(fp != nullptr, "can not open file \"%s\"\n", fname); return fp; } } // namespace utils @@ -193,7 +185,7 @@ inline std::FILE *FopenCheck(const char *fname, const char *flag) { template inline T *BeginPtr(std::vector &vec) { // NOLINT(*) if (vec.size() == 0) { - return NULL; + return nullptr; } else { return &vec[0]; } @@ -202,17 +194,17 @@ inline T *BeginPtr(std::vector &vec) { // NOLINT(*) template inline const T *BeginPtr(const std::vector &vec) { // NOLINT(*) if (vec.size() == 0) { - return NULL; + return nullptr; } else { return &vec[0]; } } inline char* BeginPtr(std::string &str) { // NOLINT(*) - if (str.length() == 0) return NULL; + if (str.length() == 0) return nullptr; return &str[0]; } inline const char* BeginPtr(const std::string &str) { - if (str.length() == 0) return NULL; + if (str.length() == 0) 
return nullptr; return &str[0]; } } // namespace rabit diff --git a/rabit/include/rabit/rabit.h b/rabit/include/rabit/rabit.h index 396354e68..23c96c47f 100644 --- a/rabit/include/rabit/rabit.h +++ b/rabit/include/rabit/rabit.h @@ -53,12 +53,12 @@ namespace rabit { * \brief defines stream used in rabit * see definition of Stream in dmlc/io.h */ -typedef dmlc::Stream Stream; +using Stream = dmlc::Stream; /*! * \brief defines serializable objects used in rabit * see definition of Serializable in dmlc/io.h */ -typedef dmlc::Serializable Serializable; +using Serializable = dmlc::Serializable; /*! * \brief reduction operators namespace @@ -199,8 +199,8 @@ inline void Broadcast(std::string *sendrecv_data, int root, */ template inline void Allreduce(DType *sendrecvbuf, size_t count, - void (*prepare_fun)(void *) = NULL, - void *prepare_arg = NULL, + void (*prepare_fun)(void *) = nullptr, + void *prepare_arg = nullptr, const char* _file = _FILE, const int _line = _LINE, const char* _caller = _CALLER); @@ -220,7 +220,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, * \param _file caller file name used to generate unique cache key * \param _line caller line number used to generate unique cache key * \param _caller caller function name used to generate unique cache key -*/ +*/ template inline void Allgather(DType *sendrecvbuf_, size_t total_size, @@ -291,7 +291,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, * \sa CheckPoint, VersionNumber */ inline int LoadCheckPoint(Serializable *global_model, - Serializable *local_model = NULL); + Serializable *local_model = nullptr); /*! * \brief checkpoints the model, meaning a stage of execution has finished. 
* every time we call check point, a version number will be increased by one @@ -307,7 +307,7 @@ inline int LoadCheckPoint(Serializable *global_model, * \sa LoadCheckPoint, VersionNumber */ inline void CheckPoint(const Serializable *global_model, - const Serializable *local_model = NULL); + const Serializable *local_model = nullptr); /*! * \brief This function can be used to replace CheckPoint for global_model only, * when certain condition is met (see detailed explanation). @@ -366,8 +366,8 @@ class Reducer { * \param _caller caller function name used to generate unique cache key */ inline void Allreduce(DType *sendrecvbuf, size_t count, - void (*prepare_fun)(void *) = NULL, - void *prepare_arg = NULL, + void (*prepare_fun)(void *) = nullptr, + void *prepare_arg = nullptr, const char* _file = _FILE, const int _line = _LINE, const char* _caller = _CALLER); @@ -422,8 +422,8 @@ class SerializeReducer { */ inline void Allreduce(DType *sendrecvobj, size_t max_nbyte, size_t count, - void (*prepare_fun)(void *) = NULL, - void *prepare_arg = NULL, + void (*prepare_fun)(void *) = nullptr, + void *prepare_arg = nullptr, const char* _file = _FILE, const int _line = _LINE, const char* _caller = _CALLER); diff --git a/rabit/include/rabit/serializable.h b/rabit/include/rabit/serializable.h index 581262feb..77508292a 100644 --- a/rabit/include/rabit/serializable.h +++ b/rabit/include/rabit/serializable.h @@ -15,12 +15,12 @@ namespace rabit { * \brief defines stream used in rabit * see definition of Stream in dmlc/io.h */ -typedef dmlc::Stream Stream; +using Stream = dmlc::Stream ; /*! 
* \brief defines serializable objects used in rabit * see definition of Serializable in dmlc/io.h */ -typedef dmlc::Serializable Serializable; +using Serializable = dmlc::Serializable; } // namespace rabit #endif // RABIT_SERIALIZABLE_H_ diff --git a/rabit/src/allreduce_base.cc b/rabit/src/allreduce_base.cc index a5b199df8..34b5c6c33 100644 --- a/rabit/src/allreduce_base.cc +++ b/rabit/src/allreduce_base.cc @@ -15,7 +15,7 @@ namespace rabit { namespace engine { // constructor -AllreduceBase::AllreduceBase(void) { +AllreduceBase::AllreduceBase() { tracker_uri = "NULL"; tracker_port = 9000; host_uri = ""; @@ -24,7 +24,7 @@ AllreduceBase::AllreduceBase(void) { rank = 0; world_size = -1; connect_retry = 5; - hadoop_mode = 0; + hadoop_mode = false; version_number = 0; // 32 K items reduce_ring_mincount = 32 << 10; @@ -32,27 +32,27 @@ AllreduceBase::AllreduceBase(void) { tree_reduce_minsize = 1 << 20; // tracker URL task_id = "NULL"; - err_link = NULL; + err_link = nullptr; dmlc_role = "worker"; this->SetParam("rabit_reduce_buffer", "256MB"); // setup possible enviroment variable of interest // include dmlc support direct variables - env_vars.push_back("DMLC_TASK_ID"); - env_vars.push_back("DMLC_ROLE"); - env_vars.push_back("DMLC_NUM_ATTEMPT"); - env_vars.push_back("DMLC_TRACKER_URI"); - env_vars.push_back("DMLC_TRACKER_PORT"); - env_vars.push_back("DMLC_WORKER_CONNECT_RETRY"); + env_vars.emplace_back("DMLC_TASK_ID"); + env_vars.emplace_back("DMLC_ROLE"); + env_vars.emplace_back("DMLC_NUM_ATTEMPT"); + env_vars.emplace_back("DMLC_TRACKER_URI"); + env_vars.emplace_back("DMLC_TRACKER_PORT"); + env_vars.emplace_back("DMLC_WORKER_CONNECT_RETRY"); } // initialization function bool AllreduceBase::Init(int argc, char* argv[]) { // setup from enviroment variables // handler to get variables from env - for (size_t i = 0; i < env_vars.size(); ++i) { - const char *value = getenv(env_vars[i].c_str()); - if (value != NULL) { - this->SetParam(env_vars[i].c_str(), value); + for (auto & 
env_var : env_vars) { + const char *value = getenv(env_var.c_str()); + if (value != nullptr) { + this->SetParam(env_var.c_str(), value); } } // pass in arguments override env variable. @@ -66,35 +66,35 @@ bool AllreduceBase::Init(int argc, char* argv[]) { { // handling for hadoop const char *task_id = getenv("mapred_tip_id"); - if (task_id == NULL) { + if (task_id == nullptr) { task_id = getenv("mapreduce_task_id"); } if (hadoop_mode) { - utils::Check(task_id != NULL, + utils::Check(task_id != nullptr, "hadoop_mode is set but cannot find mapred_task_id"); } - if (task_id != NULL) { + if (task_id != nullptr) { this->SetParam("rabit_task_id", task_id); this->SetParam("rabit_hadoop_mode", "1"); } const char *attempt_id = getenv("mapred_task_id"); - if (attempt_id != 0) { + if (attempt_id != nullptr) { const char *att = strrchr(attempt_id, '_'); int num_trial; - if (att != NULL && sscanf(att + 1, "%d", &num_trial) == 1) { + if (att != nullptr && sscanf(att + 1, "%d", &num_trial) == 1) { this->SetParam("rabit_num_trial", att + 1); } } // handling for hadoop const char *num_task = getenv("mapred_map_tasks"); - if (num_task == NULL) { + if (num_task == nullptr) { num_task = getenv("mapreduce_job_maps"); } if (hadoop_mode) { - utils::Check(num_task != NULL, + utils::Check(num_task != nullptr, "hadoop_mode is set but cannot find mapred_map_tasks"); } - if (num_task != NULL) { + if (num_task != nullptr) { this->SetParam("rabit_world_size", num_task); } } @@ -115,10 +115,10 @@ bool AllreduceBase::Init(int argc, char* argv[]) { return this->ReConnectLinks(); } -bool AllreduceBase::Shutdown(void) { +bool AllreduceBase::Shutdown() { try { - for (size_t i = 0; i < all_links.size(); ++i) { - all_links[i].sock.Close(); + for (auto & all_link : all_links) { + all_link.sock.Close(); } all_links.clear(); tree_links.plinks.clear(); @@ -208,17 +208,18 @@ void AllreduceBase::SetParam(const char *name, const char *val) { utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non 
negative second"); } if (!strcmp(name, "rabit_enable_tcp_no_delay")) { - if (!strcmp(val, "true")) + if (!strcmp(val, "true")) { rabit_enable_tcp_no_delay = true; - else + } else { rabit_enable_tcp_no_delay = false; +} } } /*! * \brief initialize connection to the tracker * \return a socket that initializes the connection */ -utils::TCPSocket AllreduceBase::ConnectTracker(void) const { +utils::TCPSocket AllreduceBase::ConnectTracker() const { int magic = kMagic; // get information from tracker utils::TCPSocket tracker; @@ -241,7 +242,7 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const { } } break; - } while (1); + } while (true); using utils::Assert; Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic), @@ -320,19 +321,19 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { do { // send over good links std::vector good_link; - for (size_t i = 0; i < all_links.size(); ++i) { - if (!all_links[i].sock.BadSocket()) { - good_link.push_back(static_cast(all_links[i].rank)); + for (auto & all_link : all_links) { + if (!all_link.sock.BadSocket()) { + good_link.push_back(static_cast(all_link.rank)); } else { - if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close(); + if (!all_link.sock.IsClosed()) all_link.sock.Close(); } } int ngood = static_cast(good_link.size()); Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood), "ReConnectLink failure 5"); - for (size_t i = 0; i < good_link.size(); ++i) { - Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \ - sizeof(good_link[i]), "ReConnectLink failure 6"); + for (int & i : good_link) { + Assert(tracker.SendAll(&i, sizeof(i)) == \ + sizeof(i), "ReConnectLink failure 6"); } Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn), "ReConnectLink failure 7"); @@ -362,11 +363,11 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { utils::Check(hrank == r.rank, "ReConnectLink failure, link rank inconsistent"); bool match = false; - for (size_t i = 0; i < 
all_links.size(); ++i) { - if (all_links[i].rank == hrank) { - Assert(all_links[i].sock.IsClosed(), + for (auto & all_link : all_links) { + if (all_link.rank == hrank) { + Assert(all_link.sock.IsClosed(), "Override a link that is active"); - all_links[i].sock = r.sock; + all_link.sock = r.sock; match = true; break; } @@ -390,11 +391,11 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 15"); bool match = false; - for (size_t i = 0; i < all_links.size(); ++i) { - if (all_links[i].rank == r.rank) { - utils::Assert(all_links[i].sock.IsClosed(), + for (auto & all_link : all_links) { + if (all_link.rank == r.rank) { + utils::Assert(all_link.sock.IsClosed(), "Override a link that is active"); - all_links[i].sock = r.sock; + all_link.sock = r.sock; match = true; break; } @@ -406,29 +407,29 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { // setup tree links and ring structure tree_links.plinks.clear(); int tcpNoDelay = 1; - for (size_t i = 0; i < all_links.size(); ++i) { - utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket"); + for (auto & all_link : all_links) { + utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket"); // set the socket to non-blocking mode, enable TCP keepalive - all_links[i].sock.SetNonBlock(true); - all_links[i].sock.SetKeepAlive(true); + all_link.sock.SetNonBlock(true); + all_link.sock.SetKeepAlive(true); if (rabit_enable_tcp_no_delay) { - setsockopt(all_links[i].sock, IPPROTO_TCP, + setsockopt(all_link.sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&tcpNoDelay), sizeof(tcpNoDelay)); } - if (tree_neighbors.count(all_links[i].rank) != 0) { - if (all_links[i].rank == parent_rank) { + if (tree_neighbors.count(all_link.rank) != 0) { + if (all_link.rank == parent_rank) { parent_index = static_cast(tree_links.plinks.size()); } - tree_links.plinks.push_back(&all_links[i]); + tree_links.plinks.push_back(&all_link); } 
- if (all_links[i].rank == prev_rank) ring_prev = &all_links[i]; - if (all_links[i].rank == next_rank) ring_next = &all_links[i]; + if (all_link.rank == prev_rank) ring_prev = &all_link; + if (all_link.rank == next_rank) ring_next = &all_link; } Assert(parent_rank == -1 || parent_index != -1, "cannot find parent in the link"); - Assert(prev_rank == -1 || ring_prev != NULL, + Assert(prev_rank == -1 || ring_prev != nullptr, "cannot find prev ring in the link"); - Assert(next_rank == -1 || ring_next != NULL, + Assert(next_rank == -1 || ring_next != nullptr, "cannot find next ring in the link"); return true; } catch (const std::exception& e) { @@ -479,11 +480,11 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_, size_t count, ReduceFunction reducer) { RefLinkVector &links = tree_links; - if (links.size() == 0 || count == 0) return kSuccess; + if (links.Size() == 0 || count == 0) return kSuccess; // total size of message const size_t total_size = type_nbytes * count; // number of links - const int nlink = static_cast(links.size()); + const int nlink = static_cast(links.Size()); // send recv buffer char *sendrecvbuf = reinterpret_cast(sendrecvbuf_); // size of space that we already performs reduce in up pass @@ -581,8 +582,9 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_, // if max reduce is less than total size, we reduce multiple times of // eachreduce size - if (max_reduce < total_size) + if (max_reduce < total_size) { max_reduce = max_reduce - max_reduce % eachreduce; +} // peform reduce, can be at most two rounds while (size_up_reduce < max_reduce) { @@ -677,11 +679,11 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_, AllreduceBase::ReturnType AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { RefLinkVector &links = tree_links; - if (links.size() == 0 || total_size == 0) return kSuccess; + if (links.Size() == 0 || total_size == 0) return kSuccess; utils::Check(root < world_size, "Broadcast: root should be smaller than 
world size"); // number of links - const int nlink = static_cast(links.size()); + const int nlink = static_cast(links.Size()); // size of space already read from data size_t size_in = 0; // input link, -2 means unknown yet, -1 means this is root diff --git a/rabit/src/allreduce_base.h b/rabit/src/allreduce_base.h index c7ef638ff..9be23f83a 100644 --- a/rabit/src/allreduce_base.h +++ b/rabit/src/allreduce_base.h @@ -25,7 +25,7 @@ #endif // RABIT_CXXTESTDEFS_H -namespace MPI { +namespace MPI { // NOLINT // MPI data type to be compatible with existing MPI interface class Datatype { public: @@ -41,12 +41,12 @@ class AllreduceBase : public IEngine { // magic number to verify server static const int kMagic = 0xff99; // constant one byte out of band message to indicate error happening - AllreduceBase(void); - virtual ~AllreduceBase(void) {} + AllreduceBase(); + virtual ~AllreduceBase() = default; // initialize the manager virtual bool Init(int argc, char* argv[]); // shutdown the engine - virtual bool Shutdown(void); + virtual bool Shutdown(); /*! * \brief set parameters to the engine * \param name parameter name @@ -59,27 +59,27 @@ class AllreduceBase : public IEngine { * the user who monitors the tracker * \param msg message to be printed in the tracker */ - virtual void TrackerPrint(const std::string &msg); + void TrackerPrint(const std::string &msg) override; /*! \brief get rank of previous node in ring topology*/ - virtual int GetRingPrevRank(void) const { + int GetRingPrevRank() const override { return ring_prev->rank; } /*! \brief get rank */ - virtual int GetRank(void) const { + int GetRank() const override { return rank; } /*! \brief get rank */ - virtual int GetWorldSize(void) const { + int GetWorldSize() const override { if (world_size == -1) return 1; return world_size; } /*! \brief whether is distributed or not */ - virtual bool IsDistributed(void) const { + bool IsDistributed() const override { return tracker_uri != "NULL"; } /*! 
\brief get rank */ - virtual std::string GetHost(void) const { + std::string GetHost() const override { return host_uri; } @@ -99,13 +99,10 @@ class AllreduceBase : public IEngine { * \param _line caller line number used to generate unique cache key * \param _caller caller function name used to generate unique cache key */ - virtual void Allgather(void *sendrecvbuf_, size_t total_size, - size_t slice_begin, - size_t slice_end, - size_t size_prev_slice, - const char* _file = _FILE, - const int _line = _LINE, - const char* _caller = _CALLER) { + void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin, + size_t slice_end, size_t size_prev_slice, + const char *_file = _FILE, const int _line = _LINE, + const char *_caller = _CALLER) override { if (world_size == 1 || world_size == -1) return; utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size, slice_begin, slice_end, size_prev_slice) == kSuccess, @@ -126,16 +123,12 @@ class AllreduceBase : public IEngine { * \param _line caller line number used to generate unique cache key * \param _caller caller function name used to generate unique cache key */ - virtual void Allreduce(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer, - PreprocFunction prepare_fun = NULL, - void *prepare_arg = NULL, - const char* _file = _FILE, - const int _line = _LINE, - const char* _caller = _CALLER) { - if (prepare_fun != NULL) prepare_fun(prepare_arg); + void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, + ReduceFunction reducer, PreprocFunction prepare_fun = nullptr, + void *prepare_arg = nullptr, const char *_file = _FILE, + const int _line = _LINE, + const char *_caller = _CALLER) override { + if (prepare_fun != nullptr) prepare_fun(prepare_arg); if (world_size == 1 || world_size == -1) return; utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) == kSuccess, @@ -150,8 +143,9 @@ class AllreduceBase : public IEngine { * \param _line caller line number used 
to generate unique cache key * \param _caller caller function name used to generate unique cache key */ - virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root, - const char* _file = _FILE, const int _line = _LINE, const char* _caller = _CALLER) { + void Broadcast(void *sendrecvbuf_, size_t total_size, int root, + const char *_file = _FILE, const int _line = _LINE, + const char *_caller = _CALLER) override { if (world_size == 1 || world_size == -1) return; utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess, "Broadcast failed"); @@ -178,8 +172,8 @@ class AllreduceBase : public IEngine { * * \sa CheckPoint, VersionNumber */ - virtual int LoadCheckPoint(Serializable *global_model, - Serializable *local_model = NULL) { + int LoadCheckPoint(Serializable *global_model, + Serializable *local_model = nullptr) override { return 0; } /*! @@ -198,8 +192,8 @@ class AllreduceBase : public IEngine { * * \sa LoadCheckPoint, VersionNumber */ - virtual void CheckPoint(const Serializable *global_model, - const Serializable *local_model = NULL) { + void CheckPoint(const Serializable *global_model, + const Serializable *local_model = nullptr) override { version_number += 1; } /*! @@ -222,7 +216,7 @@ class AllreduceBase : public IEngine { * is the same in all nodes * \sa LoadCheckPoint, CheckPoint, VersionNumber */ - virtual void LazyCheckPoint(const Serializable *global_model) { + void LazyCheckPoint(const Serializable *global_model) override { version_number += 1; } /*! @@ -230,7 +224,7 @@ class AllreduceBase : public IEngine { * which means how many calls to CheckPoint we made so far * \sa LoadCheckPoint, CheckPoint */ - virtual int VersionNumber(void) const { + int VersionNumber() const override { return version_number; } /*! 
@@ -238,14 +232,14 @@ class AllreduceBase : public IEngine { * call this function when IEngine throw an exception out, * this function is only used for test purpose */ - virtual void InitAfterException(void) { + void InitAfterException() override { utils::Error("InitAfterException: not implemented"); } /*! * \brief report current status to the job tracker * depending on the job tracker we are in */ - inline void ReportStatus(void) const { + inline void ReportStatus() const { if (hadoop_mode != 0) { fprintf(stderr, "reporter:status:Rabit Phase[%03d] Operation %03d\n", version_number, seq_counter); @@ -274,7 +268,7 @@ class AllreduceBase : public IEngine { /*! \brief internal return type */ ReturnTypeEnum value; // constructor - ReturnType() {} + ReturnType() = default; ReturnType(ReturnTypeEnum value) : value(value) {} // NOLINT(*) inline bool operator==(const ReturnTypeEnum &v) const { return value == v; @@ -306,13 +300,11 @@ class AllreduceBase : public IEngine { // size of data sent to the link size_t size_write; // pointer to buffer head - char *buffer_head; + char *buffer_head {nullptr}; // buffer size, in bytes - size_t buffer_size; + size_t buffer_size {0}; // constructor - LinkRecord(void) - : buffer_head(NULL), buffer_size(0) { - } + LinkRecord() = default; // initialize buffer inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) { @@ -328,7 +320,7 @@ class AllreduceBase : public IEngine { buffer_head = reinterpret_cast(BeginPtr(buffer_)); } // reset the recv and sent size - inline void ResetSize(void) { + inline void ResetSize() { size_write = size_read = 0; } /*! 
@@ -340,7 +332,7 @@ class AllreduceBase : public IEngine { * \return the type of reading */ inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) { - utils::Assert(buffer_head != NULL, "ReadToRingBuffer: buffer not allocated"); + utils::Assert(buffer_head != nullptr, "ReadToRingBuffer: buffer not allocated"); utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check"); size_t ngap = size_read - protect_start; utils::Assert(ngap <= buffer_size, "Allreduce: boundary check"); @@ -405,7 +397,7 @@ class AllreduceBase : public IEngine { inline LinkRecord &operator[](size_t i) { return *plinks[i]; } - inline size_t size(void) const { + inline size_t Size() const { return plinks.size(); } }; @@ -413,7 +405,7 @@ class AllreduceBase : public IEngine { * \brief initialize connection to the tracker * \return a socket that initializes the connection */ - utils::TCPSocket ConnectTracker(void) const; + utils::TCPSocket ConnectTracker() const; /*! * \brief connect to the tracker to fix the the missing links * this function is also used when the engine start up @@ -525,64 +517,64 @@ class AllreduceBase : public IEngine { //---- data structure related to model ---- // call sequence counter, records how many calls we made so far // from last call to CheckPoint, LoadCheckPoint - int seq_counter; + int seq_counter; // NOLINT // version number of model - int version_number; + int version_number; // NOLINT // whether the job is running in hadoop - bool hadoop_mode; + bool hadoop_mode; // NOLINT //---- local data related to link ---- // index of parent link, can be -1, meaning this is root of the tree - int parent_index; + int parent_index; // NOLINT // rank of parent node, can be -1 - int parent_rank; + int parent_rank; // NOLINT // sockets of all links this connects to - std::vector all_links; + std::vector all_links; // NOLINT // used to record the link where things goes wrong - LinkRecord *err_link; + LinkRecord *err_link; // NOLINT // 
all the links in the reduction tree connection - RefLinkVector tree_links; + RefLinkVector tree_links; // NOLINT // pointer to links in the ring - LinkRecord *ring_prev, *ring_next; + LinkRecord *ring_prev, *ring_next; // NOLINT //----- meta information----- // list of enviroment variables that are of possible interest - std::vector env_vars; + std::vector env_vars; // NOLINT // unique identifier of the possible job this process is doing // used to assign ranks, optional, default to NULL - std::string task_id; + std::string task_id; // NOLINT // uri of current host, to be set by Init - std::string host_uri; + std::string host_uri; // NOLINT // uri of tracker - std::string tracker_uri; + std::string tracker_uri; // NOLINT // role in dmlc jobs - std::string dmlc_role; + std::string dmlc_role; // NOLINT // port of tracker address - int tracker_port; + int tracker_port; // NOLINT // port of slave process - int slave_port, nport_trial; + int slave_port, nport_trial; // NOLINT // reduce buffer size - size_t reduce_buffer_size; + size_t reduce_buffer_size; // NOLINT // reduction method - int reduce_method; + int reduce_method; // NOLINT // mininum count of cells to use ring based method - size_t reduce_ring_mincount; + size_t reduce_ring_mincount; // NOLINT // minimul block size per tree reduce - size_t tree_reduce_minsize; + size_t tree_reduce_minsize; // NOLINT // current rank - int rank; + int rank; // NOLINT // world size - int world_size; + int world_size; // NOLINT // connect retry time - int connect_retry; + int connect_retry; // NOLINT // enable bootstrap cache 0 false 1 true - bool rabit_bootstrap_cache = false; + bool rabit_bootstrap_cache = false; // NOLINT // enable detailed logging - bool rabit_debug = false; + bool rabit_debug = false; // NOLINT // by default, if rabit worker not recover in half an hour exit - int timeout_sec = 1800; + int timeout_sec = 1800; // NOLINT // flag to enable rabit_timeout - bool rabit_timeout = false; + bool rabit_timeout = 
false; // NOLINT // Enable TCP node delay - bool rabit_enable_tcp_no_delay = false; + bool rabit_enable_tcp_no_delay = false; // NOLINT }; } // namespace engine } // namespace rabit diff --git a/rabit/src/allreduce_mock.h b/rabit/src/allreduce_mock.h index ab9f0e0e7..7c0a25e80 100644 --- a/rabit/src/allreduce_mock.h +++ b/rabit/src/allreduce_mock.h @@ -20,74 +20,65 @@ namespace engine { class AllreduceMock : public AllreduceRobust { public: // constructor - AllreduceMock(void) { - num_trial = 0; - force_local = 0; - report_stats = 0; - tsum_allreduce = 0.0; - tsum_allgather = 0.0; + AllreduceMock() { + num_trial_ = 0; + force_local_ = 0; + report_stats_ = 0; + tsum_allreduce_ = 0.0; + tsum_allgather_ = 0.0; } // destructor - virtual ~AllreduceMock(void) {} - virtual void SetParam(const char *name, const char *val) { + ~AllreduceMock() override = default; + void SetParam(const char *name, const char *val) override { AllreduceRobust::SetParam(name, val); // additional parameters - if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val); - if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial = atoi(val); - if (!strcmp(name, "report_stats")) report_stats = atoi(val); - if (!strcmp(name, "force_local")) force_local = atoi(val); + if (!strcmp(name, "rabit_num_trial")) num_trial_ = atoi(val); + if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial_ = atoi(val); + if (!strcmp(name, "report_stats")) report_stats_ = atoi(val); + if (!strcmp(name, "force_local")) force_local_ = atoi(val); if (!strcmp(name, "mock")) { MockKey k; utils::Check(sscanf(val, "%d,%d,%d,%d", &k.rank, &k.version, &k.seqno, &k.ntrial) == 4, "invalid mock parameter"); - mock_map[k] = 1; + mock_map_[k] = 1; } } - virtual void Allreduce(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer, - PreprocFunction prepare_fun, - void *prepare_arg, - const char* _file = _FILE, - const int _line = _LINE, - const char* _caller = _CALLER) { - this->Verify(MockKey(rank, version_number, 
seq_counter, num_trial), "AllReduce"); + void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, + ReduceFunction reducer, PreprocFunction prepare_fun, + void *prepare_arg, const char *_file = _FILE, + const int _line = _LINE, + const char *_caller = _CALLER) override { + this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "AllReduce"); double tstart = utils::GetTime(); AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes, count, reducer, prepare_fun, prepare_arg, _file, _line, _caller); - tsum_allreduce += utils::GetTime() - tstart; + tsum_allreduce_ += utils::GetTime() - tstart; } - virtual void Allgather(void *sendrecvbuf, - size_t total_size, - size_t slice_begin, - size_t slice_end, - size_t size_prev_slice, - const char* _file = _FILE, - const int _line = _LINE, - const char* _caller = _CALLER) { - this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather"); + void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, + size_t slice_end, size_t size_prev_slice, + const char *_file = _FILE, const int _line = _LINE, + const char *_caller = _CALLER) override { + this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Allgather"); double tstart = utils::GetTime(); AllreduceRobust::Allgather(sendrecvbuf, total_size, slice_begin, slice_end, size_prev_slice, _file, _line, _caller); - tsum_allgather += utils::GetTime() - tstart; + tsum_allgather_ += utils::GetTime() - tstart; } - virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root, - const char* _file = _FILE, - const int _line = _LINE, - const char* _caller = _CALLER) { - this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast"); + void Broadcast(void *sendrecvbuf_, size_t total_size, int root, + const char *_file = _FILE, const int _line = _LINE, + const char *_caller = _CALLER) override { + this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Broadcast"); 
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller); } - virtual int LoadCheckPoint(Serializable *global_model, - Serializable *local_model) { - tsum_allreduce = 0.0; - tsum_allgather = 0.0; - time_checkpoint = utils::GetTime(); - if (force_local == 0) { + int LoadCheckPoint(Serializable *global_model, + Serializable *local_model) override { + tsum_allreduce_ = 0.0; + tsum_allgather_ = 0.0; + time_checkpoint_ = utils::GetTime(); + if (force_local_ == 0) { return AllreduceRobust::LoadCheckPoint(global_model, local_model); } else { DummySerializer dum; @@ -95,56 +86,54 @@ class AllreduceMock : public AllreduceRobust { return AllreduceRobust::LoadCheckPoint(&dum, &com); } } - virtual void CheckPoint(const Serializable *global_model, - const Serializable *local_model) { - this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint"); + void CheckPoint(const Serializable *global_model, + const Serializable *local_model) override { + this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "CheckPoint"); double tstart = utils::GetTime(); - double tbet_chkpt = tstart - time_checkpoint; - if (force_local == 0) { + double tbet_chkpt = tstart - time_checkpoint_; + if (force_local_ == 0) { AllreduceRobust::CheckPoint(global_model, local_model); } else { DummySerializer dum; ComboSerializer com(global_model, local_model); AllreduceRobust::CheckPoint(&dum, &com); } - time_checkpoint = utils::GetTime(); + time_checkpoint_ = utils::GetTime(); double tcost = utils::GetTime() - tstart; - if (report_stats != 0 && rank == 0) { + if (report_stats_ != 0 && rank == 0) { std::stringstream ss; - ss << "[v" << version_number << "] global_size=" << global_checkpoint.length() - << ",local_size=" << (local_chkpt[0].length() + local_chkpt[1].length()) + ss << "[v" << version_number << "] global_size=" << global_checkpoint_.length() + << ",local_size=" << (local_chkpt_[0].length() + local_chkpt_[1].length()) << ",check_tcost="<< 
tcost <<" sec" - << ",allreduce_tcost=" << tsum_allreduce << " sec" - << ",allgather_tcost=" << tsum_allgather << " sec" + << ",allreduce_tcost=" << tsum_allreduce_ << " sec" + << ",allgather_tcost=" << tsum_allgather_ << " sec" << ",between_chpt=" << tbet_chkpt << "sec\n"; this->TrackerPrint(ss.str()); } - tsum_allreduce = 0.0; - tsum_allgather = 0.0; + tsum_allreduce_ = 0.0; + tsum_allgather_ = 0.0; } - virtual void LazyCheckPoint(const Serializable *global_model) { - this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint"); + void LazyCheckPoint(const Serializable *global_model) override { + this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "LazyCheckPoint"); AllreduceRobust::LazyCheckPoint(global_model); } protected: // force checkpoint to local - int force_local; + int force_local_; // whether report statistics - int report_stats; + int report_stats_; // sum of allreduce - double tsum_allreduce; + double tsum_allreduce_; // sum of allgather - double tsum_allgather; - double time_checkpoint; + double tsum_allgather_; + double time_checkpoint_; private: struct DummySerializer : public Serializable { - virtual void Load(Stream *fi) { - } - virtual void Save(Stream *fo) const { - } + void Load(Stream *fi) override {} + void Save(Stream *fo) const override {} }; struct ComboSerializer : public Serializable { Serializable *lhs; @@ -155,15 +144,15 @@ class AllreduceMock : public AllreduceRobust { : lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) { } ComboSerializer(const Serializable *lhs, const Serializable *rhs) - : lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) { + : lhs(nullptr), rhs(nullptr), c_lhs(lhs), c_rhs(rhs) { } - virtual void Load(Stream *fi) { - if (lhs != NULL) lhs->Load(fi); - if (rhs != NULL) rhs->Load(fi); + void Load(Stream *fi) override { + if (lhs != nullptr) lhs->Load(fi); + if (rhs != nullptr) rhs->Load(fi); } - virtual void Save(Stream *fo) const { - if (c_lhs != NULL) c_lhs->Save(fo); - if (c_rhs != 
NULL) c_rhs->Save(fo); + void Save(Stream *fo) const override { + if (c_lhs != nullptr) c_lhs->Save(fo); + if (c_rhs != nullptr) c_rhs->Save(fo); } }; // key to identify the mock stage @@ -172,7 +161,7 @@ class AllreduceMock : public AllreduceRobust { int version; int seqno; int ntrial; - MockKey(void) {} + MockKey() = default; MockKey(int rank, int version, int seqno, int ntrial) : rank(rank), version(version), seqno(seqno), ntrial(ntrial) {} inline bool operator==(const MockKey &b) const { @@ -189,15 +178,15 @@ class AllreduceMock : public AllreduceRobust { } }; // number of failure trials - int num_trial; + int num_trial_; // record all mock actions - std::map mock_map; + std::map mock_map_; // used to generate all kinds of exceptions inline void Verify(const MockKey &key, const char *name) { - if (mock_map.count(key) != 0) { - num_trial += 1; + if (mock_map_.count(key) != 0) { + num_trial_ += 1; // data processing frameworks runs on shared process - _error("[%d]@@@Hit Mock Error:%s ", rank, name); + error_("[%d]@@@Hit Mock Error:%s ", rank, name); } } }; diff --git a/rabit/src/allreduce_robust-inl.h b/rabit/src/allreduce_robust-inl.h index 7baa14bff..7dcbbb456 100644 --- a/rabit/src/allreduce_robust-inl.h +++ b/rabit/src/allreduce_robust-inl.h @@ -40,9 +40,9 @@ AllreduceRobust::MsgPassing(const NodeType &node_value, const std::vector &edge_in, size_t out_index)) { RefLinkVector &links = tree_links; - if (links.size() == 0) return kSuccess; + if (links.Size() == 0) return kSuccess; // number of links - const int nlink = static_cast(links.size()); + const int nlink = static_cast(links.Size()); // initialize the pointers for (int i = 0; i < nlink; ++i) { links[i].ResetSize(); diff --git a/rabit/src/allreduce_robust.cc b/rabit/src/allreduce_robust.cc index de962055f..8fb5f8183 100644 --- a/rabit/src/allreduce_robust.cc +++ b/rabit/src/allreduce_robust.cc @@ -23,31 +23,32 @@ namespace rabit { namespace engine { -AllreduceRobust::AllreduceRobust(void) { - 
num_local_replica = 0; - num_global_replica = 5; - default_local_replica = 2; +AllreduceRobust::AllreduceRobust() { + num_local_replica_ = 0; + num_global_replica_ = 5; + default_local_replica_ = 2; seq_counter = 0; - cur_cache_seq = 0; - local_chkpt_version = 0; - result_buffer_round = 1; - global_lazycheck = NULL; - use_local_model = -1; - recover_counter = 0; - checkpoint_loaded = false; - env_vars.push_back("rabit_global_replica"); - env_vars.push_back("rabit_local_replica"); + cur_cache_seq_ = 0; + local_chkpt_version_ = 0; + result_buffer_round_ = 1; + global_lazycheck_ = nullptr; + use_local_model_ = -1; + recover_counter_ = 0; + checkpoint_loaded_ = false; + env_vars.emplace_back("rabit_global_replica"); + env_vars.emplace_back("rabit_local_replica"); } bool AllreduceRobust::Init(int argc, char* argv[]) { if (AllreduceBase::Init(argc, argv)) { // chenqin: alert user opted in experimental feature. - if (rabit_bootstrap_cache) utils::HandleLogInfo( + if (rabit_bootstrap_cache) { utils::HandleLogInfo( "[EXPERIMENTAL] bootstrap cache has been enabled\n"); - checkpoint_loaded = false; - if (num_global_replica == 0) { - result_buffer_round = -1; +} + checkpoint_loaded_ = false; + if (num_global_replica_ == 0) { + result_buffer_round_ = -1; } else { - result_buffer_round = std::max(world_size / num_global_replica, 1); + result_buffer_round_ = std::max(world_size / num_global_replica_, 1); } return true; } else { @@ -55,27 +56,27 @@ bool AllreduceRobust::Init(int argc, char* argv[]) { } } /*! 
\brief shutdown the engine */ -bool AllreduceRobust::Shutdown(void) { +bool AllreduceRobust::Shutdown() { try { // need to sync the exec before we shutdown, do a pesudo check point // execute checkpoint, note: when checkpoint existing, load will not happen - _assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp, - cur_cache_seq), "Shutdown: check point must return true"); + assert_(RecoverExec(nullptr, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp, + cur_cache_seq_), "Shutdown: check point must return true"); // reset result buffer - resbuf.Clear(); seq_counter = 0; - cachebuf.Clear(); cur_cache_seq = 0; - lookupbuf.Clear(); + resbuf_.Clear(); seq_counter = 0; + cachebuf_.Clear(); cur_cache_seq_ = 0; + lookupbuf_.Clear(); // execute check ack step, load happens here - _assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, - ActionSummary::kSpecialOp, cur_cache_seq), "Shutdown: check ack must return true"); + assert_(RecoverExec(nullptr, 0, ActionSummary::kCheckAck, + ActionSummary::kSpecialOp, cur_cache_seq_), "Shutdown: check ack must return true"); // travis ci only osx test hang #if defined (__APPLE__) sleep(1); #endif - shutdown_timeout = true; - if (rabit_timeout_task.valid()) { - rabit_timeout_task.wait(); - _assert(rabit_timeout_task.get(), "expect timeout task return\n"); + shutdown_timeout_ = true; + if (rabit_timeout_task_.valid()) { + rabit_timeout_task_.wait(); + assert_(rabit_timeout_task_.get(), "expect timeout task return\n"); } return AllreduceBase::Shutdown(); } catch (const std::exception& e) { @@ -91,50 +92,48 @@ bool AllreduceRobust::Shutdown(void) { */ void AllreduceRobust::SetParam(const char *name, const char *val) { AllreduceBase::SetParam(name, val); - if (!strcmp(name, "rabit_global_replica")) num_global_replica = atoi(val); + if (!strcmp(name, "rabit_global_replica")) num_global_replica_ = atoi(val); if (!strcmp(name, "rabit_local_replica")) { - num_local_replica = atoi(val); + num_local_replica_ = 
atoi(val); } } int AllreduceRobust::SetBootstrapCache(const std::string &key, const void *buf, const size_t type_nbytes, const size_t count) { - int index = -1; - for (int i = 0 ; i < cur_cache_seq; i++) { + for (int i = 0 ; i < cur_cache_seq_; i++) { size_t nsize = 0; - void* name = lookupbuf.Query(i, &nsize); + void* name = lookupbuf_.Query(i, &nsize); if (nsize == key.length() + 1 && strcmp(static_cast(name), key.c_str()) == 0) { - index = i; break; } } // we should consider way to support duplicated signatures // https://github.com/dmlc/xgboost/issues/5012 // _assert(index == -1, "immutable cache key already exists"); - _assert(type_nbytes*count > 0, "can't set empty cache"); - void* temp = cachebuf.AllocTemp(type_nbytes, count); - cachebuf.PushTemp(cur_cache_seq, type_nbytes, count); + assert_(type_nbytes*count > 0, "can't set empty cache"); + void* temp = cachebuf_.AllocTemp(type_nbytes, count); + cachebuf_.PushTemp(cur_cache_seq_, type_nbytes, count); std::memcpy(temp, buf, type_nbytes*count); std::string k(key); - void* name = lookupbuf.AllocTemp(strlen(k.c_str()) + 1, 1); - lookupbuf.PushTemp(cur_cache_seq, strlen(k.c_str()) + 1, 1); + void* name = lookupbuf_.AllocTemp(strlen(k.c_str()) + 1, 1); + lookupbuf_.PushTemp(cur_cache_seq_, strlen(k.c_str()) + 1, 1); std::memcpy(name, key.c_str(), strlen(k.c_str()) + 1); - cur_cache_seq += 1; + cur_cache_seq_ += 1; return 0; } int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf, const size_t type_nbytes, const size_t count) { // as requester sync with rest of nodes on latest cache content - if (!RecoverExec(NULL, 0, ActionSummary::kLoadBootstrapCache, - seq_counter, cur_cache_seq)) return -1; + if (!RecoverExec(nullptr, 0, ActionSummary::kLoadBootstrapCache, + seq_counter, cur_cache_seq_)) return -1; int index = -1; - for (int i = 0 ; i < cur_cache_seq; i++) { + for (int i = 0 ; i < cur_cache_seq_; i++) { size_t nsize = 0; - void* name = lookupbuf.Query(i, &nsize); + void* name = 
lookupbuf_.Query(i, &nsize); if (nsize == strlen(key.c_str()) + 1 && strcmp(reinterpret_cast(name), key.c_str()) == 0) { index = i; @@ -145,8 +144,8 @@ int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf, if (index == -1) return -1; size_t siz = 0; - void* temp = cachebuf.Query(index, &siz); - utils::Assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index"); + void* temp = cachebuf_.Query(index, &siz); + utils::Assert(cur_cache_seq_ > index, "cur_cache_seq is smaller than lookup cache seq index"); utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested"); utils::Assert(siz > 0, "cache size should be greater than 0"); std::memcpy(buf, temp, type_nbytes*count); @@ -183,19 +182,19 @@ void AllreduceRobust::Allgather(void *sendrecvbuf, + std::string(_caller) + "#" +std::to_string(total_size); // try fetch bootstrap allgather results from cache - if (!checkpoint_loaded && rabit_bootstrap_cache && + if (!checkpoint_loaded_ && rabit_bootstrap_cache && GetBootstrapCache(key, sendrecvbuf, total_size, 1) != -1) return; double start = utils::GetTime(); - bool recovered = RecoverExec(sendrecvbuf, total_size, 0, seq_counter, cur_cache_seq); + bool recovered = RecoverExec(sendrecvbuf, total_size, 0, seq_counter, cur_cache_seq_); - if (resbuf.LastSeqNo() != -1 && - (result_buffer_round == -1 || - resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) { - resbuf.DropLast(); + if (resbuf_.LastSeqNo() != -1 && + (result_buffer_round_ == -1 || + resbuf_.LastSeqNo() % result_buffer_round_ != rank % result_buffer_round_)) { + resbuf_.DropLast(); } - void *temp = resbuf.AllocTemp(total_size, 1); + void *temp = resbuf_.AllocTemp(total_size, 1); while (true) { if (recovered) { std::memcpy(temp, sendrecvbuf, total_size); break; @@ -205,7 +204,7 @@ void AllreduceRobust::Allgather(void *sendrecvbuf, slice_begin, slice_end, size_prev_slice))) { std::memcpy(sendrecvbuf, temp, total_size); 
break; } else { - recovered = RecoverExec(sendrecvbuf, total_size, 0, seq_counter, cur_cache_seq); + recovered = RecoverExec(sendrecvbuf, total_size, 0, seq_counter, cur_cache_seq_); } } } @@ -217,8 +216,8 @@ void AllreduceRobust::Allgather(void *sendrecvbuf, } // if bootstrap allgather, store and fetch through cache - if (checkpoint_loaded || !rabit_bootstrap_cache) { - resbuf.PushTemp(seq_counter, total_size, 1); + if (checkpoint_loaded_ || !rabit_bootstrap_cache) { + resbuf_.PushTemp(seq_counter, total_size, 1); seq_counter += 1; } else { SetBootstrapCache(key, sendrecvbuf, total_size, 1); @@ -251,7 +250,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_, const char* _caller) { // skip action in single node if (world_size == 1 || world_size == -1) { - if (prepare_fun != NULL) prepare_fun(prepare_arg); + if (prepare_fun != nullptr) prepare_fun(prepare_arg); return; } @@ -260,20 +259,20 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_, + std::string(_caller) + "#" +std::to_string(type_nbytes) + "x" + std::to_string(count); // try fetch bootstrap allreduce results from cache - if (!checkpoint_loaded && rabit_bootstrap_cache && + if (!checkpoint_loaded_ && rabit_bootstrap_cache && GetBootstrapCache(key, sendrecvbuf_, type_nbytes, count) != -1) return; double start = utils::GetTime(); - bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter, cur_cache_seq); + bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter, cur_cache_seq_); - if (resbuf.LastSeqNo() != -1 && - (result_buffer_round == -1 || - resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) { - resbuf.DropLast(); + if (resbuf_.LastSeqNo() != -1 && + (result_buffer_round_ == -1 || + resbuf_.LastSeqNo() % result_buffer_round_ != rank % result_buffer_round_)) { + resbuf_.DropLast(); } - if (!recovered && prepare_fun != NULL) prepare_fun(prepare_arg); - void *temp = resbuf.AllocTemp(type_nbytes, count); + if (!recovered && prepare_fun 
!= nullptr) prepare_fun(prepare_arg); + void *temp = resbuf_.AllocTemp(type_nbytes, count); while (true) { if (recovered) { std::memcpy(temp, sendrecvbuf_, type_nbytes * count); break; @@ -282,7 +281,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_, if (CheckAndRecover(TryAllreduce(temp, type_nbytes, count, reducer))) { std::memcpy(sendrecvbuf_, temp, type_nbytes * count); break; } else { - recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter, cur_cache_seq); + recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter, cur_cache_seq_); } } } @@ -294,8 +293,8 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_, } // if bootstrap allreduce, store and fetch through cache - if (checkpoint_loaded || !rabit_bootstrap_cache) { - resbuf.PushTemp(seq_counter, type_nbytes, count); + if (checkpoint_loaded_ || !rabit_bootstrap_cache) { + resbuf_.PushTemp(seq_counter, type_nbytes, count); seq_counter += 1; } else { SetBootstrapCache(key, sendrecvbuf_, type_nbytes, count); @@ -320,17 +319,19 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root, std::string key = std::string(_file) + "::" + std::to_string(_line) + "::" + std::string(_caller) + "#" +std::to_string(total_size) + "@" + std::to_string(root); // try fetch bootstrap allreduce results from cache - if (!checkpoint_loaded && rabit_bootstrap_cache && - GetBootstrapCache(key, sendrecvbuf_, total_size, 1) != -1) return; - double start = utils::GetTime(); - bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter, cur_cache_seq); - // now we are free to remove the last result, if any - if (resbuf.LastSeqNo() != -1 && - (result_buffer_round == -1 || - resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) { - resbuf.DropLast(); + if (!checkpoint_loaded_ && rabit_bootstrap_cache && + GetBootstrapCache(key, sendrecvbuf_, total_size, 1) != -1) { + return; } - void *temp = resbuf.AllocTemp(1, total_size); + double start = 
utils::GetTime(); + bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter, cur_cache_seq_); + // now we are free to remove the last result, if any + if (resbuf_.LastSeqNo() != -1 && + (result_buffer_round_ == -1 || + resbuf_.LastSeqNo() % result_buffer_round_ != rank % result_buffer_round_)) { + resbuf_.DropLast(); + } + void *temp = resbuf_.AllocTemp(1, total_size); while (true) { if (recovered) { std::memcpy(temp, sendrecvbuf_, total_size); break; @@ -338,7 +339,7 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root, if (CheckAndRecover(TryBroadcast(sendrecvbuf_, total_size, root))) { std::memcpy(temp, sendrecvbuf_, total_size); break; } else { - recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter, cur_cache_seq); + recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter, cur_cache_seq_); } } } @@ -351,8 +352,8 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root, rank, key.c_str(), root, version_number, seq_counter, delta); } // if bootstrap broadcast, store and fetch through cache - if (checkpoint_loaded || !rabit_bootstrap_cache) { - resbuf.PushTemp(seq_counter, 1, total_size); + if (checkpoint_loaded_ || !rabit_bootstrap_cache) { + resbuf_.PushTemp(seq_counter, 1, total_size); seq_counter += 1; } else { SetBootstrapCache(key, sendrecvbuf_, total_size, 1); @@ -382,46 +383,46 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root, */ int AllreduceRobust::LoadCheckPoint(Serializable *global_model, Serializable *local_model) { - checkpoint_loaded = true; + checkpoint_loaded_ = true; // skip action in single node if (world_size == 1) return 0; - this->LocalModelCheck(local_model != NULL); - if (num_local_replica == 0) { - utils::Check(local_model == NULL, + this->LocalModelCheck(local_model != nullptr); + if (num_local_replica_ == 0) { + utils::Check(local_model == nullptr, "need to set rabit_local_replica larger than 1 to checkpoint 
local_model"); } double start = utils::GetTime(); // check if we succeed - if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp, cur_cache_seq)) { - int nlocal = std::max(static_cast(local_rptr[local_chkpt_version].size()) - 1, 0); - if (local_model != NULL) { - if (nlocal == num_local_replica + 1) { + if (RecoverExec(nullptr, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp, cur_cache_seq_)) { + int nlocal = std::max(static_cast(local_rptr_[local_chkpt_version_].size()) - 1, 0); + if (local_model != nullptr) { + if (nlocal == num_local_replica_ + 1) { // load in local model - utils::MemoryFixSizeBuffer fs(BeginPtr(local_chkpt[local_chkpt_version]), - local_rptr[local_chkpt_version][1]); + utils::MemoryFixSizeBuffer fs(BeginPtr(local_chkpt_[local_chkpt_version_]), + local_rptr_[local_chkpt_version_][1]); local_model->Load(&fs); } else { - _assert(nlocal == 0, "[%d] local model inconsistent, nlocal=%d", rank, nlocal); + assert_(nlocal == 0, "[%d] local model inconsistent, nlocal=%d", rank, nlocal); } } // reset result buffer - resbuf.Clear(); seq_counter = 0; + resbuf_.Clear(); seq_counter = 0; // load from buffer - utils::MemoryBufferStream fs(&global_checkpoint); - if (global_checkpoint.length() == 0) { + utils::MemoryBufferStream fs(&global_checkpoint_); + if (global_checkpoint_.length() == 0) { version_number = 0; } else { - _assert(fs.Read(&version_number, sizeof(version_number)) != 0, + assert_(fs.Read(&version_number, sizeof(version_number)) != 0, "read in version number"); global_model->Load(&fs); - _assert(local_model == NULL || nlocal == num_local_replica + 1, + assert_(local_model == nullptr || nlocal == num_local_replica_ + 1, "local model inconsistent, nlocal=%d", nlocal); } // run another phase of check ack, if recovered from data - _assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, - ActionSummary::kSpecialOp, cur_cache_seq), "check ack must return true"); + assert_(RecoverExec(nullptr, 0, ActionSummary::kCheckAck, 
+ ActionSummary::kSpecialOp, cur_cache_seq_), "check ack must return true"); - if (!RecoverExec(NULL, 0, ActionSummary::kLoadBootstrapCache, seq_counter, cur_cache_seq)) { + if (!RecoverExec(nullptr, 0, ActionSummary::kLoadBootstrapCache, seq_counter, cur_cache_seq_)) { utils::Printf("no need to load cache\n"); } double delta = utils::GetTime() - start; @@ -430,7 +431,7 @@ int AllreduceRobust::LoadCheckPoint(Serializable *global_model, if (rabit_debug) { utils::HandleLogInfo("[%d] loadcheckpoint size %ld finished version %d, " "seq %d, take %f seconds\n", - rank, global_checkpoint.length(), + rank, global_checkpoint_.length(), version_number, seq_counter, delta); } return version_number; @@ -439,7 +440,7 @@ int AllreduceRobust::LoadCheckPoint(Serializable *global_model, if (rabit_debug) utils::HandleLogInfo("[%d] loadcheckpoint reset\n", rank); // reset result buffer - resbuf.Clear(); seq_counter = 0; version_number = 0; + resbuf_.Clear(); seq_counter = 0; version_number = 0; // nothing loaded, a fresh start, everyone init model return version_number; } @@ -453,18 +454,18 @@ int AllreduceRobust::LoadCheckPoint(Serializable *global_model, * \param with_local whether the user calls CheckPoint with local model */ void AllreduceRobust::LocalModelCheck(bool with_local) { - if (use_local_model == -1) { + if (use_local_model_ == -1) { if (with_local) { - use_local_model = 1; - if (num_local_replica == 0) { - num_local_replica = default_local_replica; + use_local_model_ = 1; + if (num_local_replica_ == 0) { + num_local_replica_ = default_local_replica_; } } else { - use_local_model = 0; - num_local_replica = 0; + use_local_model_ = 0; + num_local_replica_ = 0; } } else { - utils::Check(use_local_model == static_cast(with_local), + utils::Check(use_local_model_ == static_cast(with_local), "Can only call Checkpoint/LoadCheckPoint always with"\ "or without local_model, but not mixed case"); } @@ -481,57 +482,57 @@ void AllreduceRobust::LocalModelCheck(bool with_local) { * * 
\sa CheckPoint, LazyCheckPoint */ -void AllreduceRobust::CheckPoint_(const Serializable *global_model, - const Serializable *local_model, - bool lazy_checkpt) { +void AllreduceRobust::CheckPointImpl(const Serializable *global_model, + const Serializable *local_model, + bool lazy_checkpt) { // never do check point in single machine mode if (world_size == 1) { version_number += 1; return; } double start = utils::GetTime(); - this->LocalModelCheck(local_model != NULL); - if (num_local_replica == 0) { - utils::Check(local_model == NULL, + this->LocalModelCheck(local_model != nullptr); + if (num_local_replica_ == 0) { + utils::Check(local_model == nullptr, "need to set rabit_local_replica larger than 1 to checkpoint local_model"); } - if (num_local_replica != 0) { + if (num_local_replica_ != 0) { while (true) { - if (RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckPoint)) break; + if (RecoverExec(nullptr, 0, 0, ActionSummary::kLocalCheckPoint)) break; // save model to new version place - int new_version = !local_chkpt_version; + int new_version = !local_chkpt_version_; - local_chkpt[new_version].clear(); - utils::MemoryBufferStream fs(&local_chkpt[new_version]); - if (local_model != NULL) { + local_chkpt_[new_version].clear(); + utils::MemoryBufferStream fs(&local_chkpt_[new_version]); + if (local_model != nullptr) { local_model->Save(&fs); } - local_rptr[new_version].clear(); - local_rptr[new_version].push_back(0); - local_rptr[new_version].push_back(local_chkpt[new_version].length()); - if (CheckAndRecover(TryCheckinLocalState(&local_rptr[new_version], - &local_chkpt[new_version]))) break; + local_rptr_[new_version].clear(); + local_rptr_[new_version].push_back(0); + local_rptr_[new_version].push_back(local_chkpt_[new_version].length()); + if (CheckAndRecover(TryCheckinLocalState(&local_rptr_[new_version], + &local_chkpt_[new_version]))) break; } // run the ack phase, can be true or false - RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckAck); + 
RecoverExec(nullptr, 0, 0, ActionSummary::kLocalCheckAck); // switch pointer to new version - local_chkpt_version = !local_chkpt_version; + local_chkpt_version_ = !local_chkpt_version_; } // execute checkpoint, note: when checkpoint existing, load will not happen - _assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, - ActionSummary::kSpecialOp, cur_cache_seq), + assert_(RecoverExec(nullptr, 0, ActionSummary::kCheckPoint, + ActionSummary::kSpecialOp, cur_cache_seq_), "check point must return true"); // this is the critical region where we will change all the stored models // increase version number version_number += 1; // save model if (lazy_checkpt) { - global_lazycheck = global_model; + global_lazycheck_ = global_model; } else { - global_checkpoint.resize(0); - utils::MemoryBufferStream fs(&global_checkpoint); + global_checkpoint_.resize(0); + utils::MemoryBufferStream fs(&global_checkpoint_); fs.Write(&version_number, sizeof(version_number)); global_model->Save(&fs); - global_lazycheck = NULL; + global_lazycheck_ = nullptr; } double delta = utils::GetTime() - start; // log checkpoint latency @@ -542,10 +543,10 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model, } start = utils::GetTime(); // reset result buffer, mark boostrap phase complete - resbuf.Clear(); seq_counter = 0; + resbuf_.Clear(); seq_counter = 0; // execute check ack step, load happens here - _assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, - ActionSummary::kSpecialOp, cur_cache_seq), "check ack must return true"); + assert_(RecoverExec(nullptr, 0, ActionSummary::kCheckAck, + ActionSummary::kSpecialOp, cur_cache_seq_), "check ack must return true"); delta = utils::GetTime() - start; // log checkpoint ack latency @@ -564,7 +565,7 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model, * when kSockError is returned, it simply means there are bad sockets in the links, * and some link recovery proceduer is needed */ -AllreduceRobust::ReturnType 
AllreduceRobust::TryResetLinks(void) { +AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks() { // number of links const int nlink = static_cast(all_links.size()); for (int i = 0; i < nlink; ++i) { @@ -618,7 +619,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { if (all_links[i].size_read == 0) { int atmark = all_links[i].sock.AtMark(); if (atmark < 0) { - _assert(all_links[i].sock.BadSocket(), "must already gone bad"); + assert_(all_links[i].sock.BadSocket(), "must already gone bad"); } else if (atmark > 0) { all_links[i].size_read = 1; } else { @@ -640,10 +641,10 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { if (len == 0) { all_links[i].sock.Close(); continue; } else if (len > 0) { - _assert(oob_mark == kResetMark, "wrong oob msg"); - _assert(all_links[i].sock.AtMark() != 1, "should already read past mark"); + assert_(oob_mark == kResetMark, "wrong oob msg"); + assert_(all_links[i].sock.AtMark() != 1, "should already read past mark"); } else { - _assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); + assert_(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); } // send out ack char ack = kResetAck; @@ -664,9 +665,9 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { if (len == 0) { all_links[i].sock.Close(); continue; } else if (len > 0) { - _assert(ack == kResetAck, "wrong Ack MSG"); + assert_(ack == kResetAck, "wrong Ack MSG"); } else { - _assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); + assert_(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); } // set back to nonblock mode all_links[i].sock.SetNonBlock(true); @@ -685,15 +686,15 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { * \return true if err_type is kSuccess, false otherwise */ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) { - shutdown_timeout = err_type == kSuccess; + shutdown_timeout_ = err_type == kSuccess; if (err_type == kSuccess) return true; - _assert(err_link != NULL, "must know the error 
link"); - recover_counter += 1; + assert_(err_link != nullptr, "must know the error link"); + recover_counter_ += 1; // async launch timeout task if enable_rabit_timeout is set - if (rabit_timeout && !rabit_timeout_task.valid()) { + if (rabit_timeout && !rabit_timeout_task_.valid()) { utils::Printf("[EXPERIMENTAL] timeout thread expires in %d second(s)\n", timeout_sec); - rabit_timeout_task = std::async(std::launch::async, [=]() { + rabit_timeout_task_ = std::async(std::launch::async, [=]() { if (rabit_debug) { utils::Printf("[%d] timeout thread %ld starts\n", rank, std::this_thread::get_id()); @@ -702,7 +703,7 @@ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) { // check if rabit recovered every 100ms while (time++ < 10 * timeout_sec) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); - if (shutdown_timeout.load()) { + if (shutdown_timeout_.load()) { if (rabit_debug) { utils::Printf("[%d] timeout task thread %ld exits\n", rank, std::this_thread::get_id()); @@ -710,13 +711,13 @@ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) { return true; } } - _error("[%d] exit due to time out %d s\n", rank, timeout_sec); + error_("[%d] exit due to time out %d s\n", rank, timeout_sec); return false; }); } // simple way, shutdown all links - for (size_t i = 0; i < all_links.size(); ++i) { - if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close(); + for (auto & all_link : all_links) { + if (!all_link.sock.BadSocket()) all_link.sock.Close(); } // smooth out traffic to tracker std::this_thread::sleep_for(std::chrono::milliseconds(10*rank)); @@ -836,8 +837,8 @@ AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role, // set p_req_in (*p_req_in)[i] = (req_in[i] != 0); if (req_out[i] != 0) { - _assert(req_in[i] == 0, "cannot get and receive request"); - _assert(static_cast(i) == best_link, "request result inconsistent"); + assert_(req_in[i] == 0, "cannot get and receive request"); + assert_(static_cast(i) == best_link, "request result 
inconsistent"); } } *p_recvlink = best_link; @@ -866,21 +867,21 @@ AllreduceRobust::TryRecoverData(RecoverType role, const std::vector &req_in) { RefLinkVector &links = tree_links; // no need to run recovery for zero size messages - if (links.size() == 0 || size == 0) return kSuccess; - _assert(req_in.size() == links.size(), "TryRecoverData"); - const int nlink = static_cast(links.size()); + if (links.Size() == 0 || size == 0) return kSuccess; + assert_(req_in.size() == links.Size(), "TryRecoverData"); + const int nlink = static_cast(links.Size()); { bool req_data = role == kRequestData; for (int i = 0; i < nlink; ++i) { if (req_in[i]) { - _assert(i != recv_link, "TryDecideRouting"); + assert_(i != recv_link, "TryDecideRouting"); req_data = true; } } // do not need to provide data or receive data, directly exit if (!req_data) return kSuccess; } - _assert(recv_link >= 0 || role == kHaveData, "recv_link must be active"); + assert_(recv_link >= 0 || role == kHaveData, "recv_link must be active"); if (role == kPassData) { links[recv_link].InitBuffer(1, size, reduce_buffer_size); } @@ -947,7 +948,7 @@ AllreduceRobust::TryRecoverData(RecoverType role, for (int i = 0; i < nlink; ++i) { if (req_in[i]) min_write = std::min(links[i].size_write, min_write); } - _assert(min_write <= links[pid].size_read, "boundary check"); + assert_(min_write <= links[pid].size_read, "boundary check"); ReturnType ret = links[pid].ReadToRingBuffer(min_write, size); if (ret != kSuccess) { return ReportError(&links[pid], ret); @@ -981,10 +982,10 @@ AllreduceRobust::ReturnType AllreduceRobust::TryRestoreCache(bool requester, const int min_seq, const int max_seq) { // clear requester and rebuild from those with most cache entries if (requester) { - _assert(cur_cache_seq <= max_seq, "requester is expected to have fewer cache entries"); - cachebuf.Clear(); - lookupbuf.Clear(); - cur_cache_seq = 0; + assert_(cur_cache_seq_ <= max_seq, "requester is expected to have fewer cache entries"); + 
cachebuf_.Clear(); + lookupbuf_.Clear(); + cur_cache_seq_ = 0; } RecoverType role = requester ? kRequestData : kHaveData; size_t size = 1; @@ -998,23 +999,23 @@ AllreduceRobust::ReturnType AllreduceRobust::TryRestoreCache(bool requester, for (int i = 0; i < max_seq; i++) { // restore lookup map size_t cache_size = 0; - void* key = lookupbuf.Query(i, &cache_size); + void* key = lookupbuf_.Query(i, &cache_size); ret = TryRecoverData(role, &cache_size, sizeof(size_t), recv_link, req_in); if (ret != kSuccess) return ret; if (requester) { - key = lookupbuf.AllocTemp(cache_size, 1); - lookupbuf.PushTemp(i, cache_size, 1); + key = lookupbuf_.AllocTemp(cache_size, 1); + lookupbuf_.PushTemp(i, cache_size, 1); } ret = TryRecoverData(role, key, cache_size, recv_link, req_in); if (ret != kSuccess) return ret; // restore cache content cache_size = 0; - void* buf = cachebuf.Query(i, &cache_size); + void* buf = cachebuf_.Query(i, &cache_size); ret = TryRecoverData(role, &cache_size, sizeof(size_t), recv_link, req_in); if (requester) { - buf = cachebuf.AllocTemp(cache_size, 1); - cachebuf.PushTemp(i, cache_size, 1); - cur_cache_seq +=1; + buf = cachebuf_.AllocTemp(cache_size, 1); + cachebuf_.PushTemp(i, cache_size, 1); + cur_cache_seq_ +=1; } ret = TryRecoverData(role, buf, cache_size, recv_link, req_in); if (ret != kSuccess) return ret; @@ -1038,20 +1039,20 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) { // check in local data RecoverType role = requester ? 
kRequestData : kHaveData; ReturnType succ; - if (num_local_replica != 0) { + if (num_local_replica_ != 0) { if (requester) { // clear existing history, if any, before load - local_rptr[local_chkpt_version].clear(); - local_chkpt[local_chkpt_version].clear(); + local_rptr_[local_chkpt_version_].clear(); + local_chkpt_[local_chkpt_version_].clear(); } // recover local checkpoint - succ = TryRecoverLocalState(&local_rptr[local_chkpt_version], - &local_chkpt[local_chkpt_version]); + succ = TryRecoverLocalState(&local_rptr_[local_chkpt_version_], + &local_chkpt_[local_chkpt_version_]); if (succ != kSuccess) return succ; - int nlocal = std::max(static_cast(local_rptr[local_chkpt_version].size()) - 1, 0); + int nlocal = std::max(static_cast(local_rptr_[local_chkpt_version_].size()) - 1, 0); // check if everyone is OK unsigned state = 0; - if (nlocal == num_local_replica + 1) { + if (nlocal == num_local_replica_ + 1) { // complete recovery state = 1; } else if (nlocal == 0) { @@ -1067,24 +1068,24 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) { "LoadCheckPoint: too many nodes fails, cannot recover local state"); } // do call save model if the checkpoint was lazy - if (role == kHaveData && global_lazycheck != NULL) { - global_checkpoint.resize(0); - utils::MemoryBufferStream fs(&global_checkpoint); + if (role == kHaveData && global_lazycheck_ != nullptr) { + global_checkpoint_.resize(0); + utils::MemoryBufferStream fs(&global_checkpoint_); fs.Write(&version_number, sizeof(version_number)); - global_lazycheck->Save(&fs); - global_lazycheck = NULL; + global_lazycheck_->Save(&fs); + global_lazycheck_ = nullptr; } // recover global checkpoint - size_t size = this->global_checkpoint.length(); + size_t size = this->global_checkpoint_.length(); int recv_link; std::vector req_in; succ = TryDecideRouting(role, &size, &recv_link, &req_in); if (succ != kSuccess) return succ; if (role == kRequestData) { - global_checkpoint.resize(size); + 
global_checkpoint_.resize(size); } if (size == 0) return kSuccess; - return TryRecoverData(role, BeginPtr(global_checkpoint), size, recv_link, req_in); + return TryRecoverData(role, BeginPtr(global_checkpoint_), size, recv_link, req_in); } /*! * \brief try to get the result of operation specified by seqno @@ -1107,19 +1108,19 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re if (seqno == ActionSummary::kLocalCheckAck) return kSuccess; if (seqno == ActionSummary::kLocalCheckPoint) { // new version of local model - int new_version = !local_chkpt_version; - int nlocal = std::max(static_cast(local_rptr[new_version].size()) - 1, 0); + int new_version = !local_chkpt_version_; + int nlocal = std::max(static_cast(local_rptr_[new_version].size()) - 1, 0); // if we goes to this place, use must have already setup the state once - _assert(nlocal == 1 || nlocal == num_local_replica + 1, + assert_(nlocal == 1 || nlocal == num_local_replica_ + 1, "TryGetResult::Checkpoint"); - return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]); + return TryRecoverLocalState(&local_rptr_[new_version], &local_chkpt_[new_version]); } // handles normal data recovery RecoverType role; if (!requester) { - sendrecvbuf = resbuf.Query(seqno, &size); - role = sendrecvbuf != NULL ? kHaveData : kPassData; + sendrecvbuf = resbuf_.Query(seqno, &size); + role = sendrecvbuf != nullptr ? 
kHaveData : kPassData; } else { role = kRequestData; } @@ -1160,13 +1161,13 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno, // kLoadBootstrapCache should be treated similar as allreduce // when loadcheck/check/checkack runs in other nodes if (flag != 0 && flag != ActionSummary::kLoadBootstrapCache) { - _assert(seqno == ActionSummary::kSpecialOp, "must only set seqno for normal operations"); + assert_(seqno == ActionSummary::kSpecialOp, "must only set seqno for normal operations"); } std::string msg = std::string(caller) + " pass negative seqno " + std::to_string(seqno) + " flag " + std::to_string(flag) + " version " + std::to_string(version_number); - _assert(seqno >=0, msg.c_str()); + assert_(seqno >=0, msg.c_str()); ActionSummary req(flag, flag, seqno, cache_seqno); @@ -1177,33 +1178,33 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno, // get the reduced action if (!CheckAndRecover(TryAllreduce(&act, sizeof(act), 1, ActionSummary::Reducer))) continue; - if (act.check_ack()) { - if (act.check_point()) { + if (act.CheckAck()) { + if (act.CheckPoint()) { // if we also have check_point, do check point first - _assert(!act.diff_seq(), + assert_(!act.DiffSeq(), "check ack & check pt cannot occur together with normal ops"); // if we requested checkpoint, we are free to go - if (req.check_point()) return true; - } else if (act.load_check()) { + if (req.CheckPoint()) return true; + } else if (act.LoadCheck()) { // if there is only check_ack and load_check, do load_check - if (!CheckAndRecover(TryLoadCheckPoint(req.load_check()))) continue; + if (!CheckAndRecover(TryLoadCheckPoint(req.LoadCheck()))) continue; // if requested load check, then misson complete - if (req.load_check()) return true; + if (req.LoadCheck()) return true; } else { // there is no check point and no load check, execute check ack - if (req.check_ack()) return true; + if (req.CheckAck()) return true; } // if execute to this point // this 
means the action requested has not been completed // try next round } else { - if (act.check_point()) { - if (act.diff_seq()) { - _assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug"); + if (act.CheckPoint()) { + if (act.DiffSeq()) { + assert_(act.Seqno() != ActionSummary::kSpecialOp, "min seq bug"); // print checkpoint consensus flag if user turn on debug if (rabit_debug) { - req.print_flags(rank, "checkpoint req"); - act.print_flags(rank, "checkpoint act"); + req.PrintFlags(rank, "checkpoint req"); + act.PrintFlags(rank, "checkpoint act"); } /* * Chen Qin @@ -1219,82 +1220,82 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno, * after catch up to checkpoint n, diff_seq will be false * */ // assume requester is falling behind - bool requester = req.seqno() == act.seqno(); + bool requester = req.Seqno() == act.Seqno(); // if not load cache - if (!act.load_cache()) { - if (act.seqno() > 0) { + if (!act.LoadCache()) { + if (act.Seqno() > 0) { if (!requester) { - _assert(req.check_point(), "checkpoint node should be KHaveData role"); - buf = resbuf.Query(act.seqno(), &size); - _assert(buf != NULL, "buf should have data from resbuf"); - _assert(size > 0, "buf size should be greater than 0"); + assert_(req.CheckPoint(), "checkpoint node should be KHaveData role"); + buf = resbuf_.Query(act.Seqno(), &size); + assert_(buf != nullptr, "buf should have data from resbuf"); + assert_(size > 0, "buf size should be greater than 0"); } - if (!CheckAndRecover(TryGetResult(buf, size, act.seqno(), requester))) continue; + if (!CheckAndRecover(TryGetResult(buf, size, act.Seqno(), requester))) continue; } } else { // cache seq no should be smaller than kSpecialOp - _assert(act.seqno(SeqType::kCache) != ActionSummary::kSpecialOp, + assert_(act.Seqno(SeqType::kCache) != ActionSummary::kSpecialOp, "checkpoint with kSpecialOp"); - int max_cache_seq = cur_cache_seq; + int max_cache_seq = cur_cache_seq_; if (TryAllreduce(&max_cache_seq, 
sizeof(max_cache_seq), 1, op::Reducer) != kSuccess) continue; - if (TryRestoreCache(req.load_cache(), act.seqno(), max_cache_seq) + if (TryRestoreCache(req.LoadCache(), act.Seqno(), max_cache_seq) != kSuccess) continue; } if (requester) return true; } else { // no difference in seq no, means we are free to check point - if (req.check_point()) return true; + if (req.CheckPoint()) return true; } } else { // no check point - if (act.load_check()) { + if (act.LoadCheck()) { // all the nodes called load_check, this is an incomplete action - if (!act.diff_seq()) return false; + if (!act.DiffSeq()) return false; // load check have higher priority, do load_check - if (!CheckAndRecover(TryLoadCheckPoint(req.load_check()))) continue; + if (!CheckAndRecover(TryLoadCheckPoint(req.LoadCheck()))) continue; // if requested load check, then misson complete - if (req.load_check()) return true; + if (req.LoadCheck()) return true; } else { // run all nodes in a isolated cache restore logic - if (act.load_cache()) { + if (act.LoadCache()) { // print checkpoint consensus flag if user turn on debug if (rabit_debug) { - req.print_flags(rank, "loadcache req"); - act.print_flags(rank, "loadcache act"); + req.PrintFlags(rank, "loadcache req"); + act.PrintFlags(rank, "loadcache act"); } // load cache should not running in parralel with other states - _assert(!act.load_check(), + assert_(!act.LoadCheck(), "load cache state expect no nodes doing load checkpoint"); - _assert(!act.check_point() , + assert_(!act.CheckPoint() , "load cache state expect no nodes doing checkpoint"); - _assert(!act.check_ack(), + assert_(!act.CheckAck(), "load cache state expect no nodes doing checkpoint ack"); // if all nodes are requester in load cache, skip - if (act.load_cache(SeqType::kCache)) return false; + if (act.LoadCache(SeqType::kCache)) return false; // bootstrap cache always restore before loadcheckpoint // requester always have seq diff with non requester - if (act.diff_seq()) { + if (act.DiffSeq()) { 
// restore cache failed, retry from what's left - if (TryRestoreCache(req.load_cache(), act.seqno(), act.seqno(SeqType::kCache)) + if (TryRestoreCache(req.LoadCache(), act.Seqno(), act.Seqno(SeqType::kCache)) != kSuccess) continue; } // if requested load cache, then mission complete - if (req.load_cache()) return true; + if (req.LoadCache()) return true; continue; } // assert no req with load cache set goes into seq catch up - _assert(!req.load_cache(), "load cache not interacte with rest states"); + assert_(!req.LoadCache(), "load cache not interacte with rest states"); // no special flags, no checkpoint, check ack, load_check - _assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug"); - if (act.diff_seq()) { - bool requester = req.seqno() == act.seqno(); - if (!CheckAndRecover(TryGetResult(buf, size, act.seqno(), requester))) continue; + assert_(act.Seqno() != ActionSummary::kSpecialOp, "min seq bug"); + if (act.DiffSeq()) { + bool requester = req.Seqno() == act.Seqno(); + if (!CheckAndRecover(TryGetResult(buf, size, act.Seqno(), requester))) continue; if (requester) return true; } else { // all the request is same, @@ -1306,7 +1307,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno, // something is still incomplete try next round } } - _assert(false, "RecoverExec: should not reach here"); + assert_(false, "RecoverExec: should not reach here"); return true; } /*! 
@@ -1329,18 +1330,18 @@ AllreduceRobust::ReturnType AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, std::string *p_local_chkpt) { // if there is no local replica, we can do nothing - if (num_local_replica == 0) return kSuccess; + if (num_local_replica_ == 0) return kSuccess; std::vector &rptr = *p_local_rptr; std::string &chkpt = *p_local_chkpt; if (rptr.size() == 0) { rptr.push_back(0); - _assert(chkpt.length() == 0, "local chkpt space inconsistent"); + assert_(chkpt.length() == 0, "local chkpt space inconsistent"); } - const int n = num_local_replica; + const int n = num_local_replica_; { // backward passing, passing state in backward direction of the ring const int nlocal = static_cast(rptr.size() - 1); - _assert(nlocal <= n + 1, "invalid local replica"); + assert_(nlocal <= n + 1, "invalid local replica"); std::vector msg_back(n + 1); msg_back[0] = nlocal; // backward passing one hop the request @@ -1394,7 +1395,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, { // forward passing, passing state in forward direction of the ring const int nlocal = static_cast(rptr.size() - 1); - _assert(nlocal <= n + 1, "invalid local replica"); + assert_(nlocal <= n + 1, "invalid local replica"); std::vector msg_forward(n + 1); msg_forward[0] = nlocal; // backward passing one hop the request @@ -1476,12 +1477,12 @@ AllreduceRobust::ReturnType AllreduceRobust::TryCheckinLocalState(std::vector *p_local_rptr, std::string *p_local_chkpt) { // if there is no local replica, we can do nothing - if (num_local_replica == 0) return kSuccess; + if (num_local_replica_ == 0) return kSuccess; std::vector &rptr = *p_local_rptr; std::string &chkpt = *p_local_chkpt; - _assert(rptr.size() == 2, + assert_(rptr.size() == 2, "TryCheckinLocalState must have exactly 1 state"); - const int n = num_local_replica; + const int n = num_local_replica_; std::vector sizes(n + 1); sizes[0] = rptr[1] - rptr[0]; ReturnType succ; @@ -1534,11 +1535,11 @@ 
AllreduceRobust::RingPassing(void *sendrecvbuf_, size_t write_end, LinkRecord *read_link, LinkRecord *write_link) { - if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess; - _assert(write_end <= read_end, + if (read_link == nullptr || write_link == nullptr || read_end == 0) return kSuccess; + assert_(write_end <= read_end, "RingPassing: boundary check1"); - _assert(read_ptr <= read_end, "RingPassing: boundary check2"); - _assert(write_ptr <= write_end, "RingPassing: boundary check3"); + assert_(read_ptr <= read_end, "RingPassing: boundary check2"); + assert_(write_ptr <= write_end, "RingPassing: boundary check3"); // take reference LinkRecord &prev = *read_link, &next = *write_link; // send recv buffer diff --git a/rabit/src/allreduce_robust.h b/rabit/src/allreduce_robust.h index a4bee7c58..02bd353ea 100644 --- a/rabit/src/allreduce_robust.h +++ b/rabit/src/allreduce_robust.h @@ -22,18 +22,18 @@ namespace engine { /*! \brief implementation of fault tolerant all reduce engine */ class AllreduceRobust : public AllreduceBase { public: - AllreduceRobust(void); - virtual ~AllreduceRobust(void) {} + AllreduceRobust(); + ~AllreduceRobust() override = default; // initialize the manager - virtual bool Init(int argc, char* argv[]); + bool Init(int argc, char* argv[]) override; /*! \brief shutdown the engine */ - virtual bool Shutdown(void); + bool Shutdown() override; /*! * \brief set parameters to the engine * \param name parameter name * \param val parameter value */ - virtual void SetParam(const char *name, const char *val); + void SetParam(const char *name, const char *val) override; /*! 
* \brief perform immutable local bootstrap cache insertion * \param key unique cache key @@ -67,13 +67,10 @@ class AllreduceRobust : public AllreduceBase { * \param _line caller line number used to generate unique cache key * \param _caller caller function name used to generate unique cache key */ - virtual void Allgather(void *sendrecvbuf_, size_t total_size, - size_t slice_begin, - size_t slice_end, - size_t size_prev_slice, - const char* _file = _FILE, - const int _line = _LINE, - const char* _caller = _CALLER); + void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin, + size_t slice_end, size_t size_prev_slice, + const char *_file = _FILE, const int _line = _LINE, + const char *_caller = _CALLER) override; /*! * \brief perform in-place allreduce, on sendrecvbuf * this function is NOT thread-safe @@ -90,15 +87,11 @@ class AllreduceRobust : public AllreduceBase { * \param _line caller line number used to generate unique cache key * \param _caller caller function name used to generate unique cache key */ - virtual void Allreduce(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer, - PreprocFunction prepare_fun = NULL, - void *prepare_arg = NULL, - const char* _file = _FILE, - const int _line = _LINE, - const char* _caller = _CALLER); + void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, + ReduceFunction reducer, PreprocFunction prepare_fun = nullptr, + void *prepare_arg = nullptr, const char *_file = _FILE, + const int _line = _LINE, + const char *_caller = _CALLER) override; /*! 
* \brief broadcast data from root to all nodes * \param sendrecvbuf_ buffer for both sending and recving data @@ -108,10 +101,9 @@ class AllreduceRobust : public AllreduceBase { * \param _line caller line number used to generate unique cache key * \param _caller caller function name used to generate unique cache key */ - virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root, - const char* _file = _FILE, - const int _line = _LINE, - const char* _caller = _CALLER); + void Broadcast(void *sendrecvbuf_, size_t total_size, int root, + const char *_file = _FILE, const int _line = _LINE, + const char *_caller = _CALLER) override; /*! * \brief load latest check point * \param global_model pointer to the globally shared model/state @@ -134,8 +126,8 @@ class AllreduceRobust : public AllreduceBase { * * \sa CheckPoint, VersionNumber */ - virtual int LoadCheckPoint(Serializable *global_model, - Serializable *local_model = NULL); + int LoadCheckPoint(Serializable *global_model, + Serializable *local_model = nullptr) override; /*! * \brief checkpoint the model, meaning we finished a stage of execution * every time we call check point, there is a version number which will increase by one @@ -152,9 +144,9 @@ class AllreduceRobust : public AllreduceBase { * * \sa LoadCheckPoint, VersionNumber */ - virtual void CheckPoint(const Serializable *global_model, - const Serializable *local_model = NULL) { - this->CheckPoint_(global_model, local_model, false); + void CheckPoint(const Serializable *global_model, + const Serializable *local_model = nullptr) override { + this->CheckPointImpl(global_model, local_model, false); } /*! 
* \brief This function can be used to replace CheckPoint for global_model only, @@ -176,18 +168,20 @@ class AllreduceRobust : public AllreduceBase { * is the same in all nodes * \sa LoadCheckPoint, CheckPoint, VersionNumber */ - virtual void LazyCheckPoint(const Serializable *global_model) { - this->CheckPoint_(global_model, NULL, true); + void LazyCheckPoint(const Serializable *global_model) override { + this->CheckPointImpl(global_model, nullptr, true); } /*! * \brief explicitly re-init everything before calling LoadCheckPoint * call this function when IEngine throw an exception out, * this function is only used for test purpose */ - virtual void InitAfterException(void) { + void InitAfterException() override { // simple way, shutdown all links - for (size_t i = 0; i < all_links.size(); ++i) { - if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close(); + for (auto& link : all_links) { + if (!link.sock.BadSocket()) { + link.sock.Close(); + } } ReConnectLinks("recover"); } @@ -245,69 +239,69 @@ class AllreduceRobust : public AllreduceBase { // there are nodes request load cache static const int kLoadBootstrapCache = 16; // constructor - ActionSummary(void) {} + ActionSummary() = default; // constructor of action explicit ActionSummary(int seqno_flag, int cache_flag = 0, u_int32_t minseqno = kSpecialOp, u_int32_t maxseqno = kSpecialOp) { - seqcode = (minseqno << 5) | seqno_flag; - maxseqcode = (maxseqno << 5) | cache_flag; + seqcode_ = (minseqno << 5) | seqno_flag; + maxseqcode_ = (maxseqno << 5) | cache_flag; } // minimum number of all operations by default // maximum number of all cache operations otherwise - inline u_int32_t seqno(SeqType t = SeqType::kSeq) const { - int code = t == SeqType::kSeq ? seqcode : maxseqcode; + inline u_int32_t Seqno(SeqType t = SeqType::kSeq) const { + int code = t == SeqType::kSeq ?
seqcode_ : maxseqcode_; return code >> 5; } // whether the operation set contains a load_check - inline bool load_check(SeqType t = SeqType::kSeq) const { - int code = t == SeqType::kSeq ? seqcode : maxseqcode; + inline bool LoadCheck(SeqType t = SeqType::kSeq) const { + int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_; return (code & kLoadCheck) != 0; } // whether the operation set contains a load_cache - inline bool load_cache(SeqType t = SeqType::kSeq) const { - int code = t == SeqType::kSeq ? seqcode : maxseqcode; + inline bool LoadCache(SeqType t = SeqType::kSeq) const { + int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_; return (code & kLoadBootstrapCache) != 0; } // whether the operation set contains a check point - inline bool check_point(SeqType t = SeqType::kSeq) const { - int code = t == SeqType::kSeq ? seqcode : maxseqcode; + inline bool CheckPoint(SeqType t = SeqType::kSeq) const { + int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_; return (code & kCheckPoint) != 0; } // whether the operation set contains a check ack - inline bool check_ack(SeqType t = SeqType::kSeq) const { - int code = t == SeqType::kSeq ? seqcode : maxseqcode; + inline bool CheckAck(SeqType t = SeqType::kSeq) const { + int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_; return (code & kCheckAck) != 0; } // whether the operation set contains different sequence number - inline bool diff_seq() const { - return (seqcode & kDiffSeq) != 0; + inline bool DiffSeq() const { + return (seqcode_ & kDiffSeq) != 0; } // returns the operation flag of the result - inline int flag(SeqType t = SeqType::kSeq) const { - int code = t == SeqType::kSeq ? seqcode : maxseqcode; + inline int Flag(SeqType t = SeqType::kSeq) const { + int code = t == SeqType::kSeq ? 
seqcode_ : maxseqcode_; return code & 31; } // print flags in user friendly way - inline void print_flags(int rank, std::string prefix ) { - utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n", - rank, prefix.c_str(), - seqno(), check_point(), check_ack(), load_cache(), - diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache)); + inline void PrintFlags(int rank, std::string prefix ) { + utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n", rank, + prefix.c_str(), Seqno(), CheckPoint(), CheckAck(), + LoadCache(), DiffSeq(), Seqno(SeqType::kCache), + LoadCache(SeqType::kCache)); } // reducer for Allreduce, get the result ActionSummary from all nodes inline static void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) { - const ActionSummary *src = (const ActionSummary*)src_; + const ActionSummary *src = static_cast(src_); ActionSummary *dst = reinterpret_cast(dst_); for (int i = 0; i < len; ++i) { - u_int32_t min_seqno = std::min(src[i].seqno(), dst[i].seqno()); - u_int32_t max_seqno = std::max(src[i].seqno(SeqType::kCache), - dst[i].seqno(SeqType::kCache)); - int action_flag = src[i].flag() | dst[i].flag(); + u_int32_t min_seqno = std::min(src[i].Seqno(), dst[i].Seqno()); + u_int32_t max_seqno = std::max(src[i].Seqno(SeqType::kCache), + dst[i].Seqno(SeqType::kCache)); + int action_flag = src[i].Flag() | dst[i].Flag(); // if any node is not requester set to 0 otherwise 1 - int role_flag = src[i].flag(SeqType::kCache) & dst[i].flag(SeqType::kCache); + int role_flag = src[i].Flag(SeqType::kCache) & dst[i].Flag(SeqType::kCache); // if seqno is different in src and destination - int seq_diff_flag = src[i].seqno() != dst[i].seqno() ? kDiffSeq : 0; + int seq_diff_flag = src[i].Seqno() != dst[i].Seqno() ? 
kDiffSeq : 0; // apply or to both seq diff flag as well as cache seq diff flag dst[i] = ActionSummary(action_flag | seq_diff_flag, role_flag, min_seqno, max_seqno); @@ -316,19 +310,19 @@ class AllreduceRobust : public AllreduceBase { private: // internel sequence code min of rabit seqno - u_int32_t seqcode; + u_int32_t seqcode_; // internal sequence code max of cache seqno - u_int32_t maxseqcode; + u_int32_t maxseqcode_; }; /*! \brief data structure to remember result of Bcast and Allreduce calls*/ class ResultBuffer{ public: // constructor - ResultBuffer(void) { + ResultBuffer() { this->Clear(); } // clear the existing record - inline void Clear(void) { + inline void Clear() { seqno_.clear(); size_.clear(); rptr_.clear(); rptr_.push_back(0); data_.clear(); @@ -358,12 +352,12 @@ class AllreduceRobust : public AllreduceBase { inline void* Query(int seqid, size_t *p_size) { size_t idx = std::lower_bound(seqno_.begin(), seqno_.end(), seqid) - seqno_.begin(); - if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL; + if (idx == seqno_.size() || seqno_[idx] != seqid) return nullptr; *p_size = size_[idx]; return BeginPtr(data_) + rptr_[idx]; } // drop last stored result - inline void DropLast(void) { + inline void DropLast() { utils::Assert(seqno_.size() != 0, "there is nothing to be dropped"); seqno_.pop_back(); rptr_.pop_back(); @@ -371,7 +365,7 @@ class AllreduceRobust : public AllreduceBase { data_.resize(rptr_.back()); } // the sequence number of last stored result - inline int LastSeqNo(void) const { + inline int LastSeqNo() const { if (seqno_.size() == 0) return -1; return seqno_.back(); } @@ -407,9 +401,8 @@ class AllreduceRobust : public AllreduceBase { * * \sa CheckPoint, LazyCheckPoint */ - void CheckPoint_(const Serializable *global_model, - const Serializable *local_model, - bool lazy_checkpt); + void CheckPointImpl(const Serializable *global_model, + const Serializable *local_model, bool lazy_checkpt); /*! 
* \brief reset the all the existing links by sending Out-of-Band message marker * after this function finishes, all the messages received and sent @@ -423,7 +416,7 @@ class AllreduceRobust : public AllreduceBase { * when kSockError is returned, it simply means there are bad sockets in the links, * and some link recovery proceduer is needed */ - ReturnType TryResetLinks(void); + ReturnType TryResetLinks(); /*! * \brief if err_type indicates an error * recover links according to the error type reported @@ -619,29 +612,29 @@ o * the input state must exactly one saved state(local state of current node) size_t out_index)); //---- recovery data structure ---- // the round of result buffer, used to mode the result - int result_buffer_round; + int result_buffer_round_; // result buffer of all reduce - ResultBuffer resbuf; + ResultBuffer resbuf_; // current cached allreduce/braodcast sequence number - int cur_cache_seq; + int cur_cache_seq_; // result buffer of cached all reduce - ResultBuffer cachebuf; + ResultBuffer cachebuf_; // key of each cache entry - ResultBuffer lookupbuf; + ResultBuffer lookupbuf_; // last check point global model - std::string global_checkpoint; + std::string global_checkpoint_; // lazy checkpoint of global model - const Serializable *global_lazycheck; + const Serializable *global_lazycheck_; // number of replica for local state/model - int num_local_replica; + int num_local_replica_; // number of default local replica - int default_local_replica; + int default_local_replica_; // flag to decide whether local model is used, -1: unknown, 0: no, 1:yes - int use_local_model; + int use_local_model_; // number of replica for global state/model - int num_global_replica; + int num_global_replica_; // number of times recovery happens - int recover_counter; + int recover_counter_; // --- recovery data structure for local checkpoint // there is two version of the data structure, // at one time one version is valid and another is used as temp memory @@ 
-649,21 +642,21 @@ o * the input state must exactly one saved state(local state of current node) // local model is stored in CSR format(like a sparse matrices) // local_model[rptr[0]:rptr[1]] stores the model of current node // local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops - std::vector local_rptr[2]; + std::vector local_rptr_[2]; // storage for local model replicas - std::string local_chkpt[2]; + std::string local_chkpt_[2]; // version of local checkpoint can be 1 or 0 - int local_chkpt_version; + int local_chkpt_version_; // if checkpoint were loaded, used to distinguish results boostrap cache from seqno cache - bool checkpoint_loaded; + bool checkpoint_loaded_; // sidecar executing timeout task - std::future rabit_timeout_task; + std::future rabit_timeout_task_; // flag to shutdown rabit_timeout_task before timeout - std::atomic shutdown_timeout{false}; + std::atomic shutdown_timeout_{false}; // error handler - void (* _error)(const char *fmt, ...) = utils::Error; + void (* error_)(const char *fmt, ...) = utils::Error; // assert handler - void (* _assert)(bool exp, const char *fmt, ...) = utils::Assert; + void (* assert_)(bool exp, const char *fmt, ...) 
= utils::Assert; }; } // namespace engine } // namespace rabit diff --git a/rabit/src/c_api.cc b/rabit/src/c_api.cc index b1c8689f7..5328daf08 100644 --- a/rabit/src/c_api.cc +++ b/rabit/src/c_api.cc @@ -33,12 +33,12 @@ struct FHelper { }; template -void Allreduce_(void *sendrecvbuf_, +void Allreduce(void *sendrecvbuf_, size_t count, engine::mpi::DataType enum_dtype, void (*prepare_fun)(void *arg), void *prepare_arg) { - using namespace engine::mpi; + using namespace engine::mpi; // NOLINT switch (enum_dtype) { case kChar: rabit::Allreduce @@ -89,28 +89,28 @@ void Allreduce(void *sendrecvbuf, engine::mpi::OpType enum_op, void (*prepare_fun)(void *arg), void *prepare_arg) { - using namespace engine::mpi; + using namespace engine::mpi; // NOLINT switch (enum_op) { case kMax: - Allreduce_ + Allreduce (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; case kMin: - Allreduce_ + Allreduce (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; case kSum: - Allreduce_ + Allreduce (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; case kBitwiseOR: - Allreduce_ + Allreduce (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); @@ -124,7 +124,7 @@ void Allgather(void *sendrecvbuf_, size_t size_node_slice, size_t size_prev_slice, int enum_dtype) { - using namespace engine::mpi; + using namespace engine::mpi; // NOLINT size_t type_size = 0; switch (enum_dtype) { case kChar: @@ -184,7 +184,7 @@ struct ReadWrapper : public Serializable { std::string *p_str; explicit ReadWrapper(std::string *p_str) : p_str(p_str) {} - virtual void Load(Stream *fi) { + void Load(Stream *fi) override { uint64_t sz; utils::Assert(fi->Read(&sz, sizeof(sz)) != 0, "Read pickle string"); @@ -194,7 +194,7 @@ struct ReadWrapper : public Serializable { "Read pickle string"); } } - virtual void Save(Stream *fo) const { + void Save(Stream *fo) const override { utils::Error("not implemented"); } }; @@ -206,10 +206,10 @@ struct WriteWrapper : public 
Serializable { size_t length) : data(data), length(length) { } - virtual void Load(Stream *fi) { + void Load(Stream *fi) override { utils::Error("not implemented"); } - virtual void Save(Stream *fo) const { + void Save(Stream *fo) const override { uint64_t sz = static_cast(length); fo->Write(&sz, sizeof(sz)); fo->Write(data, length * sizeof(char)); @@ -298,8 +298,8 @@ RABIT_DLL int RabitLoadCheckPoint(char **out_global_model, ReadWrapper sl(&local_buffer); int version; - if (out_local_model == NULL) { - version = rabit::LoadCheckPoint(&sg, NULL); + if (out_local_model == nullptr) { + version = rabit::LoadCheckPoint(&sg, nullptr); *out_global_model = BeginPtr(global_buffer); *out_global_len = static_cast(global_buffer.length()); } else { @@ -317,8 +317,8 @@ RABIT_DLL void RabitCheckPoint(const char *global_model, rbt_ulong global_len, using namespace rabit::c_api; // NOLINT(*) WriteWrapper sg(global_model, global_len); WriteWrapper sl(local_model, local_len); - if (local_model == NULL) { - rabit::CheckPoint(&sg, NULL); + if (local_model == nullptr) { + rabit::CheckPoint(&sg, nullptr); } else { rabit::CheckPoint(&sg, &sl); } diff --git a/rabit/src/engine.cc b/rabit/src/engine.cc index 69d68eb30..ac7731fdd 100644 --- a/rabit/src/engine.cc +++ b/rabit/src/engine.cc @@ -7,18 +7,19 @@ * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou */ #include +#include + #include #include "rabit/internal/engine.h" #include "allreduce_base.h" #include "allreduce_robust.h" -#include "rabit/internal/thread_local.h" namespace rabit { namespace engine { // singleton sync manager #ifndef RABIT_USE_BASE #ifndef RABIT_USE_MOCK -typedef AllreduceRobust Manager; +using Manager = AllreduceRobust; #else typedef AllreduceMock Manager; #endif // RABIT_USE_MOCK @@ -31,13 +32,13 @@ struct ThreadLocalEntry { /*! \brief stores the current engine */ std::unique_ptr engine; /*! \brief whether init has been called */ - bool initialized; + bool initialized{false}; /*! 
\brief constructor */ - ThreadLocalEntry() : initialized(false) {} + ThreadLocalEntry() = default; }; // define the threadlocal store. -typedef ThreadLocalStore EngineThreadLocal; +using EngineThreadLocal = dmlc::ThreadLocalStore; /*! \brief intiialize the synchronization module */ bool Init(int argc, char *argv[]) { @@ -95,7 +96,7 @@ void Allgather(void *sendrecvbuf_, size_t total_size, // perform in-place allreduce, on sendrecvbuf -void Allreduce_(void *sendrecvbuf, +void Allreduce_(void *sendrecvbuf, // NOLINT size_t type_nbytes, size_t count, IEngine::ReduceFunction red, @@ -111,18 +112,15 @@ void Allreduce_(void *sendrecvbuf, } // code for reduce handle -ReduceHandle::ReduceHandle(void) - : handle_(NULL), redfunc_(NULL), htype_(NULL) { -} - -ReduceHandle::~ReduceHandle(void) {} +ReduceHandle::ReduceHandle() = default; +ReduceHandle::~ReduceHandle() = default; int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { return static_cast(dtype.type_size); } void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) { - utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice"); + utils::Assert(redfunc_ == nullptr, "cannot initialize reduce handle twice"); redfunc_ = redfunc; } @@ -133,7 +131,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf, const char* _file, const int _line, const char* _caller) { - utils::Assert(redfunc_ != NULL, "must intialize handle to call AllReduce"); + utils::Assert(redfunc_ != nullptr, "must intialize handle to call AllReduce"); GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, redfunc_, prepare_fun, prepare_arg, _file, _line, _caller); diff --git a/rabit/src/engine_empty.cc b/rabit/src/engine_empty.cc index 53ec85ee3..248203dc9 100644 --- a/rabit/src/engine_empty.cc +++ b/rabit/src/engine_empty.cc @@ -16,78 +16,68 @@ namespace engine { /*! 
\brief EmptyEngine */ class EmptyEngine : public IEngine { public: - EmptyEngine(void) { - version_number = 0; + EmptyEngine() { + version_number_ = 0; } - virtual void Allgather(void *sendrecvbuf_, - size_t total_size, - size_t slice_begin, - size_t slice_end, - size_t size_prev_slice, - const char* _file, - const int _line, - const char* _caller) { + void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin, + size_t slice_end, size_t size_prev_slice, const char *_file, + const int _line, const char *_caller) override { utils::Error("EmptyEngine:: Allgather is not supported"); } - virtual int GetRingPrevRank(void) const { + int GetRingPrevRank() const override { utils::Error("EmptyEngine:: GetRingPrevRank is not supported"); return -1; } - virtual void Allreduce(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer, - PreprocFunction prepare_fun, - void *prepare_arg, - const char* _file, - const int _line, - const char* _caller) { + void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, + ReduceFunction reducer, PreprocFunction prepare_fun, + void *prepare_arg, const char *_file, const int _line, + const char *_caller) override { utils::Error("EmptyEngine:: Allreduce is not supported,"\ "use Allreduce_ instead"); } - virtual void Broadcast(void *sendrecvbuf_, size_t size, int root, - const char* _file, const int _line, const char* _caller) { + void Broadcast(void *sendrecvbuf_, size_t size, int root, + const char* _file, const int _line, const char* _caller) override { } - virtual void InitAfterException(void) { + void InitAfterException() override { utils::Error("EmptyEngine is not fault tolerant"); } - virtual int LoadCheckPoint(Serializable *global_model, - Serializable *local_model = NULL) { + int LoadCheckPoint(Serializable *global_model, + Serializable *local_model = nullptr) override { return 0; } - virtual void CheckPoint(const Serializable *global_model, - const Serializable *local_model = NULL) { 
- version_number += 1; + void CheckPoint(const Serializable *global_model, + const Serializable *local_model = nullptr) override { + version_number_ += 1; } - virtual void LazyCheckPoint(const Serializable *global_model) { - version_number += 1; + void LazyCheckPoint(const Serializable *global_model) override { + version_number_ += 1; } - virtual int VersionNumber(void) const { - return version_number; + int VersionNumber() const override { + return version_number_; } /*! \brief get rank of current node */ - virtual int GetRank(void) const { + int GetRank() const override { return 0; } /*! \brief get total number of */ - virtual int GetWorldSize(void) const { + int GetWorldSize() const override { return 1; } /*! \brief whether it is distributed */ - virtual bool IsDistributed(void) const { + bool IsDistributed() const override { return false; } /*! \brief get the host name of current node */ - virtual std::string GetHost(void) const { + std::string GetHost() const override { return std::string(""); } - virtual void TrackerPrint(const std::string &msg) { + void TrackerPrint(const std::string &msg) override { // simply print information into the tracker utils::Printf("%s", msg.c_str()); } private: - int version_number; + int version_number_; }; // singleton sync manager @@ -98,12 +88,12 @@ bool Init(int argc, char *argv[]) { return true; } /*! \brief finalize syncrhonization module */ -bool Finalize(void) { +bool Finalize() { return true; } /*! 
\brief singleton method to get engine */ -IEngine *GetEngine(void) { +IEngine *GetEngine() { return &manager; } // perform in-place allreduce, on sendrecvbuf @@ -118,13 +108,12 @@ void Allreduce_(void *sendrecvbuf, const char* _file, const int _line, const char* _caller) { - if (prepare_fun != NULL) prepare_fun(prepare_arg); + if (prepare_fun != nullptr) prepare_fun(prepare_arg); } // code for reduce handle -ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) { -} -ReduceHandle::~ReduceHandle(void) {} +ReduceHandle::ReduceHandle() = default; +ReduceHandle::~ReduceHandle() = default; int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { return 0; @@ -137,7 +126,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf, const char* _file, const int _line, const char* _caller) { - if (prepare_fun != NULL) prepare_fun(prepare_arg); + if (prepare_fun != nullptr) prepare_fun(prepare_arg); } } // namespace engine } // namespace rabit diff --git a/rabit/src/engine_mock.cc b/rabit/src/engine_mock.cc index f38c423d0..5c0f8505e 100644 --- a/rabit/src/engine_mock.cc +++ b/rabit/src/engine_mock.cc @@ -1,7 +1,7 @@ /*! * Copyright (c) 2014 by Contributors * \file engine_mock.cc - * \brief this is an engine implementation that will + * \brief this is an engine implementation that will * insert failures in certain call point, to test if the engine is robust to failure * \author Tianqi Chen */ @@ -12,4 +12,3 @@ #include #include "allreduce_mock.h" #include "engine.cc" -