Correct style warnings from clang-tidy for rabit. (#6095)

This commit is contained in:
Jiaming Yuan 2020-09-08 12:13:58 +08:00 committed by GitHub
parent da61d9460b
commit b0001a6e29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 736 additions and 868 deletions

View File

@ -19,7 +19,7 @@
#define _CALLER "N/A"
#endif // (defined(__GNUC__) && !defined(__clang__))
namespace MPI {
namespace MPI { // NOLINT
/*! \brief MPI data type just to be compatible with MPI reduce function*/
class Datatype;
}
@ -36,7 +36,7 @@ class IEngine {
* used to prepare the data used by AllReduce
* \param arg additional possible argument used to invoke the preprocessor
*/
typedef void (PreprocFunction) (void *arg);
typedef void (PreprocFunction) (void *arg); // NOLINT
/*!
* \brief reduce function, the same form of MPI reduce function is used,
* to be compatible with MPI interface
@ -48,11 +48,11 @@ class IEngine {
* the definition of the reduce function should be type aware
* \param dtype the data type object, to be compatible with MPI reduce
*/
typedef void (ReduceFunction) (const void *src,
typedef void (ReduceFunction) (const void *src, // NOLINT
void *dst, int count,
const MPI::Datatype &dtype);
/*! \brief virtual destructor */
virtual ~IEngine() {}
~IEngine() = default;
/*!
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
@ -96,8 +96,8 @@ class IEngine {
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL,
PreprocFunction prepare_fun = nullptr,
void *prepare_arg = nullptr,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) = 0;
@ -119,7 +119,7 @@ class IEngine {
* call this function when IEngine throws an exception,
* this function should only be used for test purposes
*/
virtual void InitAfterException(void) = 0;
virtual void InitAfterException() = 0;
/*!
* \brief loads the latest check point
* \param global_model pointer to the globally shared model/state
@ -143,7 +143,7 @@ class IEngine {
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL) = 0;
Serializable *local_model = nullptr) = 0;
/*!
* \brief checkpoints the model, meaning a stage of execution was finished
* every time we call check point, the version number increases by one
@ -161,7 +161,7 @@ class IEngine {
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) = 0;
const Serializable *local_model = nullptr) = 0;
/*!
* \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met (see detailed explanation).
@ -188,17 +188,17 @@ class IEngine {
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
virtual int VersionNumber(void) const = 0;
virtual int VersionNumber() const = 0;
/*! \brief gets rank of previous node in ring topology */
virtual int GetRingPrevRank(void) const = 0;
virtual int GetRingPrevRank() const = 0;
/*! \brief gets rank of current node */
virtual int GetRank(void) const = 0;
virtual int GetRank() const = 0;
/*! \brief gets total number of nodes */
virtual int GetWorldSize(void) const = 0;
virtual int GetWorldSize() const = 0;
/*! \brief whether we run in distributed mode */
virtual bool IsDistributed(void) const = 0;
virtual bool IsDistributed() const = 0;
/*! \brief gets the host name of the current node */
virtual std::string GetHost(void) const = 0;
virtual std::string GetHost() const = 0;
/*!
* \brief prints the msg in the tracker,
* this function can be used to communicate progress information to
@ -211,9 +211,9 @@ class IEngine {
/*! \brief initializes the engine module */
bool Init(int argc, char *argv[]);
/*! \brief finalizes the engine module */
bool Finalize(void);
bool Finalize();
/*! \brief singleton method to get engine */
IEngine *GetEngine(void);
IEngine *GetEngine();
/*! \brief namespace that contains stubs to be compatible with MPI */
namespace mpi {
@ -280,14 +280,14 @@ void Allgather(void* sendrecvbuf,
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
void Allreduce_(void *sendrecvbuf,
void Allreduce_(void *sendrecvbuf, // NOLINT
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL,
IEngine::PreprocFunction prepare_fun = nullptr,
void *prepare_arg = nullptr,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
@ -298,9 +298,9 @@ void Allreduce_(void *sendrecvbuf,
class ReduceHandle {
public:
// constructor
ReduceHandle(void);
ReduceHandle();
// destructor
~ReduceHandle(void);
~ReduceHandle();
/*!
* \brief initialize the reduce function,
* with the type the reduce function needs to deal with
@ -323,8 +323,8 @@ class ReduceHandle {
void Allreduce(void *sendrecvbuf,
size_t type_nbytes,
size_t count,
IEngine::PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL,
IEngine::PreprocFunction prepare_fun = nullptr,
void *prepare_arg = nullptr,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
@ -333,11 +333,11 @@ class ReduceHandle {
protected:
// handle function field
void *handle_;
void *handle_ {nullptr};
// reduce function of the reducer
IEngine::ReduceFunction *redfunc_;
IEngine::ReduceFunction *redfunc_{nullptr};
// handle to the type field
void *htype_;
void *htype_{nullptr};
// the created type in 4 bytes
size_t created_type_nbytes_;
};

View File

@ -19,12 +19,12 @@
namespace rabit {
namespace utils {
/*! \brief re-use definition of dmlc::SeekStream */
typedef dmlc::SeekStream SeekStream;
using SeekStream = dmlc::SeekStream;
/*! \brief fixed size memory buffer */
struct MemoryFixSizeBuffer : public SeekStream {
public:
// similar to SEEK_END in libc
static size_t constexpr SeekEnd = std::numeric_limits<size_t>::max();
static size_t constexpr kSeekEnd = std::numeric_limits<size_t>::max();
public:
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
@ -32,31 +32,31 @@ struct MemoryFixSizeBuffer : public SeekStream {
buffer_size_(buffer_size) {
curr_ptr_ = 0;
}
virtual ~MemoryFixSizeBuffer(void) {}
virtual size_t Read(void *ptr, size_t size) {
~MemoryFixSizeBuffer() override = default;
size_t Read(void *ptr, size_t size) override {
size_t nread = std::min(buffer_size_ - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
virtual void Write(const void *ptr, size_t size) {
void Write(const void *ptr, size_t size) override {
if (size == 0) return;
utils::Assert(curr_ptr_ + size <= buffer_size_,
"write position exceed fixed buffer size");
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
curr_ptr_ += size;
}
virtual void Seek(size_t pos) {
if (pos == SeekEnd) {
void Seek(size_t pos) override {
if (pos == kSeekEnd) {
curr_ptr_ = buffer_size_;
} else {
curr_ptr_ = static_cast<size_t>(pos);
}
}
virtual size_t Tell(void) {
size_t Tell() override {
return curr_ptr_;
}
virtual bool AtEnd(void) const {
virtual bool AtEnd() const {
return curr_ptr_ == buffer_size_;
}
@ -76,8 +76,8 @@ struct MemoryBufferStream : public SeekStream {
: p_buffer_(p_buffer) {
curr_ptr_ = 0;
}
virtual ~MemoryBufferStream(void) {}
virtual size_t Read(void *ptr, size_t size) {
~MemoryBufferStream() override = default;
size_t Read(void *ptr, size_t size) override {
utils::Assert(curr_ptr_ <= p_buffer_->length(),
"read can not have position excceed buffer length");
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
@ -85,7 +85,7 @@ struct MemoryBufferStream : public SeekStream {
curr_ptr_ += nread;
return nread;
}
virtual void Write(const void *ptr, size_t size) {
void Write(const void *ptr, size_t size) override {
if (size == 0) return;
if (curr_ptr_ + size > p_buffer_->length()) {
p_buffer_->resize(curr_ptr_+size);
@ -93,13 +93,13 @@ struct MemoryBufferStream : public SeekStream {
std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
curr_ptr_ += size;
}
virtual void Seek(size_t pos) {
void Seek(size_t pos) override {
curr_ptr_ = static_cast<size_t>(pos);
}
virtual size_t Tell(void) {
size_t Tell() override {
return curr_ptr_;
}
virtual bool AtEnd(void) const {
virtual bool AtEnd() const {
return curr_ptr_ == p_buffer_->length();
}

View File

@ -19,45 +19,45 @@ namespace engine {
namespace mpi {
// template function to translate type to enum indicator
template<typename DType>
inline DataType GetType(void);
inline DataType GetType();
template<>
inline DataType GetType<char>(void) {
inline DataType GetType<char>() {
return kChar;
}
template<>
inline DataType GetType<unsigned char>(void) {
inline DataType GetType<unsigned char>() {
return kUChar;
}
template<>
inline DataType GetType<int>(void) {
inline DataType GetType<int>() {
return kInt;
}
template<>
inline DataType GetType<unsigned int>(void) { // NOLINT(*)
inline DataType GetType<unsigned int>() { // NOLINT(*)
return kUInt;
}
template<>
inline DataType GetType<long>(void) { // NOLINT(*)
inline DataType GetType<long>() { // NOLINT(*)
return kLong;
}
template<>
inline DataType GetType<unsigned long>(void) { // NOLINT(*)
inline DataType GetType<unsigned long>() { // NOLINT(*)
return kULong;
}
template<>
inline DataType GetType<float>(void) {
inline DataType GetType<float>() {
return kFloat;
}
template<>
inline DataType GetType<double>(void) {
inline DataType GetType<double>() {
return kDouble;
}
template<>
inline DataType GetType<long long>(void) { // NOLINT(*)
inline DataType GetType<long long>() { // NOLINT(*)
return kLongLong;
}
template<>
inline DataType GetType<unsigned long long>(void) { // NOLINT(*)
inline DataType GetType<unsigned long long>() { // NOLINT(*)
return kULongLong;
}
} // namespace mpi
@ -94,7 +94,7 @@ struct BitOR {
};
template<typename OP, typename DType>
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
const DType* src = (const DType*)src_;
const DType* src = static_cast<const DType*>(src_);
DType* dst = (DType*)dst_; // NOLINT(*)
for (int i = 0; i < len; i++) {
OP::Reduce(dst[i], src[i]);
@ -107,27 +107,27 @@ inline bool Init(int argc, char *argv[]) {
return engine::Init(argc, argv);
}
// finalize the rabit engine
inline bool Finalize(void) {
inline bool Finalize() {
return engine::Finalize();
}
// get the rank of the previous worker in ring topology
inline int GetRingPrevRank(void) {
inline int GetRingPrevRank() {
return engine::GetEngine()->GetRingPrevRank();
}
// get the rank of current process
inline int GetRank(void) {
inline int GetRank() {
return engine::GetEngine()->GetRank();
}
// get the size of the world
inline int GetWorldSize(void) {
inline int GetWorldSize() {
return engine::GetEngine()->GetWorldSize();
}
// whether rabit is distributed
inline bool IsDistributed(void) {
inline bool IsDistributed() {
return engine::GetEngine()->IsDistributed();
}
// get the name of current processor
inline std::string GetProcessorName(void) {
inline std::string GetProcessorName() {
return engine::GetEngine()->GetHost();
}
// broadcast data to all other nodes from root
@ -183,7 +183,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
// C++11 support for lambda prepare function
#if DMLC_USE_CXX11
inline void InvokeLambda_(void *fun) {
inline void InvokeLambda(void *fun) {
(*static_cast<std::function<void()>*>(fun))();
}
template<typename OP, typename DType>
@ -193,7 +193,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
const int _line,
const char* _caller) {
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun,
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda, &prepare_fun,
_file, _line, _caller);
}
@ -245,7 +245,7 @@ inline void LazyCheckPoint(const Serializable *global_model) {
engine::GetEngine()->LazyCheckPoint(global_model);
}
// return the version number of currently stored model
inline int VersionNumber(void) {
inline int VersionNumber() {
return engine::GetEngine()->VersionNumber();
}
// ---------------------------------
@ -253,7 +253,7 @@ inline int VersionNumber(void) {
// ---------------------------------
// function to perform reduction for Reducer
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
inline void ReducerSafeImpl(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
const size_t kUnit = sizeof(DType);
const char *psrc = reinterpret_cast<const char*>(src_);
char *pdst = reinterpret_cast<char*>(dst_);
@ -269,7 +269,7 @@ inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Data
}
// function to perform reduction for Reducer
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
inline void ReducerAlign_(const void *src_, void *dst_,
inline void ReducerAlignImpl(const void *src_, void *dst_,
int len_, const MPI::Datatype &dtype) {
const DType *psrc = reinterpret_cast<const DType*>(src_);
DType *pdst = reinterpret_cast<DType*>(dst_);
@ -278,12 +278,12 @@ inline void ReducerAlign_(const void *src_, void *dst_,
}
}
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
inline Reducer<DType, freduce>::Reducer(void) {
inline Reducer<DType, freduce>::Reducer() {
// it is safe to directly use handle for aligned data types
if (sizeof(DType) == 8 || sizeof(DType) == 4 || sizeof(DType) == 1) {
this->handle_.Init(ReducerAlign_<DType, freduce>, sizeof(DType));
this->handle_.Init(ReducerAlignImpl<DType, freduce>, sizeof(DType));
} else {
this->handle_.Init(ReducerSafe_<DType, freduce>, sizeof(DType));
this->handle_.Init(ReducerSafeImpl<DType, freduce>, sizeof(DType));
}
}
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
@ -298,8 +298,8 @@ inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
}
// function to perform reduction for SerializeReducer
template<typename DType>
inline void SerializeReducerFunc_(const void *src_, void *dst_,
int len_, const MPI::Datatype &dtype) {
inline void SerializeReducerFuncImpl(const void *src_, void *dst_,
int len_, const MPI::Datatype &dtype) {
int nbytes = engine::ReduceHandle::TypeSize(dtype);
// temp space
for (int i = 0; i < len_; ++i) {
@ -315,8 +315,8 @@ inline void SerializeReducerFunc_(const void *src_, void *dst_,
}
}
template<typename DType>
inline SerializeReducer<DType>::SerializeReducer(void) {
handle_.Init(SerializeReducerFunc_<DType>, sizeof(DType));
inline SerializeReducer<DType>::SerializeReducer() {
handle_.Init(SerializeReducerFuncImpl<DType>, sizeof(DType));
}
// closure to call Allreduce
template<typename DType>
@ -327,8 +327,8 @@ struct SerializeReduceClosure {
void *prepare_arg;
std::string *p_buffer;
// invoke the closure
inline void Run(void) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
inline void Run() {
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
for (size_t i = 0; i < count; ++i) {
utils::MemoryFixSizeBuffer fs(BeginPtr(*p_buffer) + i * max_nbyte, max_nbyte);
sendrecvobj[i].Save(fs);
@ -368,7 +368,7 @@ inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
const char* _file,
const int _line,
const char* _caller) {
this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun,
this->Allreduce(sendrecvbuf, count, InvokeLambda, &prepare_fun,
_file, _line, _caller);
}
template<typename DType>
@ -378,7 +378,7 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
const char* _file,
const int _line,
const char* _caller) {
this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda_, &prepare_fun,
this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda, &prepare_fun,
_file, _line, _caller);
}
#endif // DMLC_USE_CXX11

View File

@ -15,7 +15,7 @@
#else
#include <fcntl.h>
#include <netdb.h>
#include <errno.h>
#include <cerrno>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
@ -39,9 +39,9 @@ static inline int poll(struct pollfd *pfd, int nfds,
int timeout) { return WSAPoll ( pfd, nfds, timeout ); }
#else
#include <sys/poll.h>
typedef int SOCKET;
typedef size_t sock_size_t;
const int INVALID_SOCKET = -1;
using SOCKET = int;
using sock_size_t = size_t; // NOLINT
const int kInvalidSocket = -1;
#endif // defined(_WIN32)
namespace rabit {
@ -50,11 +50,11 @@ namespace utils {
struct SockAddr {
sockaddr_in addr;
// constructor
SockAddr(void) {}
SockAddr() = default;
SockAddr(const char *url, int port) {
this->Set(url, port);
}
inline static std::string GetHostName(void) {
inline static std::string GetHostName() {
std::string buf; buf.resize(256);
utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
return std::string(buf.c_str());
@ -69,20 +69,20 @@ struct SockAddr {
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_INET;
hints.ai_protocol = SOCK_STREAM;
addrinfo *res = NULL;
int sig = getaddrinfo(host, NULL, &hints, &res);
Check(sig == 0 && res != NULL, "cannot obtain address of %s", host);
addrinfo *res = nullptr;
int sig = getaddrinfo(host, nullptr, &hints, &res);
Check(sig == 0 && res != nullptr, "cannot obtain address of %s", host);
Check(res->ai_family == AF_INET, "Does not support IPv6");
memcpy(&addr, res->ai_addr, res->ai_addrlen);
addr.sin_port = htons(port);
freeaddrinfo(res);
}
/*! \brief return port of the address*/
inline int port(void) const {
inline int Port() const {
return ntohs(addr.sin_port);
}
/*! \return a string representation of the address */
inline std::string AddrStr(void) const {
inline std::string AddrStr() const {
std::string buf; buf.resize(256);
#ifdef _WIN32
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
@ -91,7 +91,7 @@ struct SockAddr {
const char *s = inet_ntop(AF_INET, &addr.sin_addr,
&buf[0], buf.length());
#endif // _WIN32
Assert(s != NULL, "cannot decode address");
Assert(s != nullptr, "cannot decode address");
return std::string(s);
}
};
@ -104,13 +104,13 @@ class Socket {
/*! \brief the file descriptor of socket */
SOCKET sockfd;
// default conversion to int
inline operator SOCKET() const {
operator SOCKET() const { // NOLINT
return sockfd;
}
/*!
* \return last error of socket operation
*/
inline static int GetLastError(void) {
inline static int GetLastError() {
#ifdef _WIN32
return WSAGetLastError();
#else
@ -118,7 +118,7 @@ class Socket {
#endif // _WIN32
}
/*! \return whether last error was would block */
inline static bool LastErrorWouldBlock(void) {
inline static bool LastErrorWouldBlock() {
int errsv = GetLastError();
#ifdef _WIN32
return errsv == WSAEWOULDBLOCK;
@ -130,7 +130,7 @@ class Socket {
* \brief start up the socket module
* call this before using the sockets
*/
inline static void Startup(void) {
inline static void Startup() {
#ifdef _WIN32
WSADATA wsa_data;
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
@ -145,7 +145,7 @@ class Socket {
/*!
* \brief shutdown the socket module after use, all sockets need to be closed
*/
inline static void Finalize(void) {
inline static void Finalize() {
#ifdef _WIN32
WSACleanup();
#endif // _WIN32
@ -214,7 +214,7 @@ class Socket {
return -1;
}
/*! \brief get last error code if any */
inline int GetSockError(void) const {
inline int GetSockError() const {
int error = 0;
socklen_t len = sizeof(error);
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
@ -224,25 +224,25 @@ class Socket {
return error;
}
/*! \brief check if anything bad happens */
inline bool BadSocket(void) const {
inline bool BadSocket() const {
if (IsClosed()) return true;
int err = GetSockError();
if (err == EBADF || err == EINTR) return true;
return false;
}
/*! \brief check if socket is already closed */
inline bool IsClosed(void) const {
return sockfd == INVALID_SOCKET;
inline bool IsClosed() const {
return sockfd == kInvalidSocket;
}
/*! \brief close the socket */
inline void Close(void) {
if (sockfd != INVALID_SOCKET) {
inline void Close() {
if (sockfd != kInvalidSocket) {
#ifdef _WIN32
closesocket(sockfd);
#else
close(sockfd);
#endif
sockfd = INVALID_SOCKET;
sockfd = kInvalidSocket;
} else {
Error("Socket::Close double close the socket or close without create");
}
@ -268,7 +268,7 @@ class Socket {
class TCPSocket : public Socket{
public:
// constructor
TCPSocket(void) : Socket(INVALID_SOCKET) {
TCPSocket() : Socket(kInvalidSocket) {
}
explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) {
}
@ -297,7 +297,7 @@ class TCPSocket : public Socket{
*/
inline void Create(int af = PF_INET) {
sockfd = socket(PF_INET, SOCK_STREAM, 0);
if (sockfd == INVALID_SOCKET) {
if (sockfd == kInvalidSocket) {
Socket::Error("Create");
}
}
@ -309,9 +309,9 @@ class TCPSocket : public Socket{
listen(sockfd, backlog);
}
/*! \brief get a new connection */
TCPSocket Accept(void) {
SOCKET newfd = accept(sockfd, NULL, NULL);
if (newfd == INVALID_SOCKET) {
TCPSocket Accept() {
SOCKET newfd = accept(sockfd, nullptr, nullptr);
if (newfd == kInvalidSocket) {
Socket::Error("Accept");
}
return TCPSocket(newfd);
@ -320,7 +320,7 @@ class TCPSocket : public Socket{
* \brief decide whether the socket is at OOB mark
* \return 1 if at mark, 0 if not, -1 if an error occurred
*/
inline int AtMark(void) const {
inline int AtMark() const {
#ifdef _WIN32
unsigned long atmark; // NOLINT(*)
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;

View File

@ -1,87 +0,0 @@
/*!
* Copyright (c) 2015 by Contributors
* \file thread_local.h
* \brief Common utility for thread local storage.
*/
#ifndef RABIT_INTERNAL_THREAD_LOCAL_H_
#define RABIT_INTERNAL_THREAD_LOCAL_H_
#include "../include/dmlc/base.h"
#if DMLC_ENABLE_STD_THREAD
#include <mutex>
#endif // DMLC_ENABLE_STD_THREAD
#include <memory>
#include <vector>
namespace rabit {
// Macro handling for thread-local variables: MX_TREAD_LOCAL expands to the
// thread-local storage specifier selected by the branches below.
#ifdef __GNUC__
#define MX_TREAD_LOCAL __thread
#elif __STDC_VERSION__ >= 201112L
#define MX_TREAD_LOCAL _Thread_local
#elif defined(_MSC_VER)
#define MX_TREAD_LOCAL __declspec(thread)
#endif  // __GNUC__
#ifndef MX_TREAD_LOCAL
// None of the compiler-specific spellings matched. Fall back to the standard
// C++11 keyword so the macro is always defined and code using it still builds.
#define MX_TREAD_LOCAL thread_local
#endif  // MX_TREAD_LOCAL
/*!
* \brief A threadlocal store to store threadlocal variables.
* Will return a thread local singleton of type T
* \tparam T the type we like to store
*/
template<typename T>
class ThreadLocalStore {
public:
/*! \return get a thread local singleton */
static T* Get() {
static MX_TREAD_LOCAL T* ptr = nullptr;
if (ptr == nullptr) {
ptr = new T();
Singleton()->RegisterDelete(ptr);
}
return ptr;
}
private:
/*! \brief constructor */
ThreadLocalStore() {}
/*! \brief destructor */
~ThreadLocalStore() {
for (size_t i = 0; i < data_.size(); ++i) {
delete data_[i];
}
}
/*! \return singleton of the store */
static ThreadLocalStore<T> *Singleton() {
static ThreadLocalStore<T> inst;
return &inst;
}
/*!
* \brief register str for internal deletion
* \param str the string pointer
*/
void RegisterDelete(T *str) {
#if DMLC_ENABLE_STD_THREAD
std::unique_lock<std::mutex> lock(mutex_);
data_.push_back(str);
lock.unlock();
#else
data_.push_back(str);
#endif // DMLC_ENABLE_STD_THREAD
}
#if DMLC_ENABLE_STD_THREAD
/*! \brief internal mutex */
std::mutex mutex_;
#endif // DMLC_ENABLE_STD_THREAD
/*!\brief internal data */
std::vector<T*> data_;
};
} // namespace rabit
#endif // RABIT_INTERNAL_THREAD_LOCAL_H_

View File

@ -6,7 +6,7 @@
*/
#ifndef RABIT_INTERNAL_TIMER_H_
#define RABIT_INTERNAL_TIMER_H_
#include <time.h>
#include <ctime>
#ifdef __MACH__
#include <mach/clock.h>
#include <mach/mach.h>
@ -18,7 +18,7 @@ namespace utils {
/*!
* \brief return time in seconds, not cross platform, avoid to use this in most places
*/
inline double GetTime(void) {
inline double GetTime() {
#ifdef __MACH__
clock_serv_t cclock;
mach_timespec_t mts;

View File

@ -8,7 +8,7 @@
#define RABIT_INTERNAL_UTILS_H_
#include <rabit/base.h>
#include <string.h>
#include <cstring>
#include <cstdio>
#include <string>
#include <cstdlib>
@ -48,15 +48,7 @@ extern "C" {
}
#endif // _MSC_VER
#ifdef _MSC_VER
typedef unsigned char uint8_t;
typedef unsigned __int16 uint16_t;
typedef unsigned __int32 uint32_t;
typedef unsigned __int64 uint64_t;
typedef __int64 int64_t;
#else
#include <inttypes.h>
#endif // _MSC_VER
#include <cinttypes>
namespace rabit {
/*! \brief namespace for helper utils of the project */
@ -184,7 +176,7 @@ inline void Error(const char *fmt, ...) {
/*! \brief replace fopen, report error when the file open fails */
inline std::FILE *FopenCheck(const char *fname, const char *flag) {
std::FILE *fp = fopen64(fname, flag);
Check(fp != NULL, "can not open file \"%s\"\n", fname);
Check(fp != nullptr, "can not open file \"%s\"\n", fname);
return fp;
}
} // namespace utils
@ -193,7 +185,7 @@ inline std::FILE *FopenCheck(const char *fname, const char *flag) {
template<typename T>
inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*)
if (vec.size() == 0) {
return NULL;
return nullptr;
} else {
return &vec[0];
}
@ -202,17 +194,17 @@ inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*)
template<typename T>
inline const T *BeginPtr(const std::vector<T> &vec) { // NOLINT(*)
if (vec.size() == 0) {
return NULL;
return nullptr;
} else {
return &vec[0];
}
}
inline char* BeginPtr(std::string &str) { // NOLINT(*)
if (str.length() == 0) return NULL;
if (str.length() == 0) return nullptr;
return &str[0];
}
inline const char* BeginPtr(const std::string &str) {
if (str.length() == 0) return NULL;
if (str.length() == 0) return nullptr;
return &str[0];
}
} // namespace rabit

View File

@ -53,12 +53,12 @@ namespace rabit {
* \brief defines stream used in rabit
* see definition of Stream in dmlc/io.h
*/
typedef dmlc::Stream Stream;
using Stream = dmlc::Stream;
/*!
* \brief defines serializable objects used in rabit
* see definition of Serializable in dmlc/io.h
*/
typedef dmlc::Serializable Serializable;
using Serializable = dmlc::Serializable;
/*!
* \brief reduction operators namespace
@ -199,8 +199,8 @@ inline void Broadcast(std::string *sendrecv_data, int root,
*/
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *) = NULL,
void *prepare_arg = NULL,
void (*prepare_fun)(void *) = nullptr,
void *prepare_arg = nullptr,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
@ -291,7 +291,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
* \sa CheckPoint, VersionNumber
*/
inline int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL);
Serializable *local_model = nullptr);
/*!
* \brief checkpoints the model, meaning a stage of execution has finished.
* every time we call check point, a version number will be increased by one
@ -307,7 +307,7 @@ inline int LoadCheckPoint(Serializable *global_model,
* \sa LoadCheckPoint, VersionNumber
*/
inline void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL);
const Serializable *local_model = nullptr);
/*!
* \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met (see detailed explanation).
@ -366,8 +366,8 @@ class Reducer {
* \param _caller caller function name used to generate unique cache key
*/
inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *) = NULL,
void *prepare_arg = NULL,
void (*prepare_fun)(void *) = nullptr,
void *prepare_arg = nullptr,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
@ -422,8 +422,8 @@ class SerializeReducer {
*/
inline void Allreduce(DType *sendrecvobj,
size_t max_nbyte, size_t count,
void (*prepare_fun)(void *) = NULL,
void *prepare_arg = NULL,
void (*prepare_fun)(void *) = nullptr,
void *prepare_arg = nullptr,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);

View File

@ -15,12 +15,12 @@ namespace rabit {
* \brief defines stream used in rabit
* see definition of Stream in dmlc/io.h
*/
typedef dmlc::Stream Stream;
using Stream = dmlc::Stream ;
/*!
* \brief defines serializable objects used in rabit
* see definition of Serializable in dmlc/io.h
*/
typedef dmlc::Serializable Serializable;
using Serializable = dmlc::Serializable;
} // namespace rabit
#endif // RABIT_SERIALIZABLE_H_

View File

@ -15,7 +15,7 @@
namespace rabit {
namespace engine {
// constructor
AllreduceBase::AllreduceBase(void) {
AllreduceBase::AllreduceBase() {
tracker_uri = "NULL";
tracker_port = 9000;
host_uri = "";
@ -24,7 +24,7 @@ AllreduceBase::AllreduceBase(void) {
rank = 0;
world_size = -1;
connect_retry = 5;
hadoop_mode = 0;
hadoop_mode = false;
version_number = 0;
// 32 K items
reduce_ring_mincount = 32 << 10;
@ -32,27 +32,27 @@ AllreduceBase::AllreduceBase(void) {
tree_reduce_minsize = 1 << 20;
// tracker URL
task_id = "NULL";
err_link = NULL;
err_link = nullptr;
dmlc_role = "worker";
this->SetParam("rabit_reduce_buffer", "256MB");
// set up possible environment variables of interest
// include dmlc support direct variables
env_vars.push_back("DMLC_TASK_ID");
env_vars.push_back("DMLC_ROLE");
env_vars.push_back("DMLC_NUM_ATTEMPT");
env_vars.push_back("DMLC_TRACKER_URI");
env_vars.push_back("DMLC_TRACKER_PORT");
env_vars.push_back("DMLC_WORKER_CONNECT_RETRY");
env_vars.emplace_back("DMLC_TASK_ID");
env_vars.emplace_back("DMLC_ROLE");
env_vars.emplace_back("DMLC_NUM_ATTEMPT");
env_vars.emplace_back("DMLC_TRACKER_URI");
env_vars.emplace_back("DMLC_TRACKER_PORT");
env_vars.emplace_back("DMLC_WORKER_CONNECT_RETRY");
}
// initialization function
bool AllreduceBase::Init(int argc, char* argv[]) {
// set up from environment variables
// handler to get variables from env
for (size_t i = 0; i < env_vars.size(); ++i) {
const char *value = getenv(env_vars[i].c_str());
if (value != NULL) {
this->SetParam(env_vars[i].c_str(), value);
for (auto & env_var : env_vars) {
const char *value = getenv(env_var.c_str());
if (value != nullptr) {
this->SetParam(env_var.c_str(), value);
}
}
// pass in arguments override env variable.
@ -66,35 +66,35 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
{
// handling for hadoop
const char *task_id = getenv("mapred_tip_id");
if (task_id == NULL) {
if (task_id == nullptr) {
task_id = getenv("mapreduce_task_id");
}
if (hadoop_mode) {
utils::Check(task_id != NULL,
utils::Check(task_id != nullptr,
"hadoop_mode is set but cannot find mapred_task_id");
}
if (task_id != NULL) {
if (task_id != nullptr) {
this->SetParam("rabit_task_id", task_id);
this->SetParam("rabit_hadoop_mode", "1");
}
const char *attempt_id = getenv("mapred_task_id");
if (attempt_id != 0) {
if (attempt_id != nullptr) {
const char *att = strrchr(attempt_id, '_');
int num_trial;
if (att != NULL && sscanf(att + 1, "%d", &num_trial) == 1) {
if (att != nullptr && sscanf(att + 1, "%d", &num_trial) == 1) {
this->SetParam("rabit_num_trial", att + 1);
}
}
// handling for hadoop
const char *num_task = getenv("mapred_map_tasks");
if (num_task == NULL) {
if (num_task == nullptr) {
num_task = getenv("mapreduce_job_maps");
}
if (hadoop_mode) {
utils::Check(num_task != NULL,
utils::Check(num_task != nullptr,
"hadoop_mode is set but cannot find mapred_map_tasks");
}
if (num_task != NULL) {
if (num_task != nullptr) {
this->SetParam("rabit_world_size", num_task);
}
}
@ -115,10 +115,10 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
return this->ReConnectLinks();
}
bool AllreduceBase::Shutdown(void) {
bool AllreduceBase::Shutdown() {
try {
for (size_t i = 0; i < all_links.size(); ++i) {
all_links[i].sock.Close();
for (auto & all_link : all_links) {
all_link.sock.Close();
}
all_links.clear();
tree_links.plinks.clear();
@ -208,17 +208,18 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non negative second");
}
if (!strcmp(name, "rabit_enable_tcp_no_delay")) {
if (!strcmp(val, "true"))
if (!strcmp(val, "true")) {
rabit_enable_tcp_no_delay = true;
else
} else {
rabit_enable_tcp_no_delay = false;
}
}
}
/*!
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
utils::TCPSocket AllreduceBase::ConnectTracker() const {
int magic = kMagic;
// get information from tracker
utils::TCPSocket tracker;
@ -241,7 +242,7 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
}
}
break;
} while (1);
} while (true);
using utils::Assert;
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
@ -320,19 +321,19 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
do {
// send over good links
std::vector<int> good_link;
for (size_t i = 0; i < all_links.size(); ++i) {
if (!all_links[i].sock.BadSocket()) {
good_link.push_back(static_cast<int>(all_links[i].rank));
for (auto & all_link : all_links) {
if (!all_link.sock.BadSocket()) {
good_link.push_back(static_cast<int>(all_link.rank));
} else {
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
if (!all_link.sock.IsClosed()) all_link.sock.Close();
}
}
int ngood = static_cast<int>(good_link.size());
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
"ReConnectLink failure 5");
for (size_t i = 0; i < good_link.size(); ++i) {
Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \
sizeof(good_link[i]), "ReConnectLink failure 6");
for (int & i : good_link) {
Assert(tracker.SendAll(&i, sizeof(i)) == \
sizeof(i), "ReConnectLink failure 6");
}
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
"ReConnectLink failure 7");
@ -362,11 +363,11 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
utils::Check(hrank == r.rank,
"ReConnectLink failure, link rank inconsistent");
bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == hrank) {
Assert(all_links[i].sock.IsClosed(),
for (auto & all_link : all_links) {
if (all_link.rank == hrank) {
Assert(all_link.sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock;
all_link.sock = r.sock;
match = true;
break;
}
@ -390,11 +391,11 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
"ReConnectLink failure 15");
bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == r.rank) {
utils::Assert(all_links[i].sock.IsClosed(),
for (auto & all_link : all_links) {
if (all_link.rank == r.rank) {
utils::Assert(all_link.sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock;
all_link.sock = r.sock;
match = true;
break;
}
@ -406,29 +407,29 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
// setup tree links and ring structure
tree_links.plinks.clear();
int tcpNoDelay = 1;
for (size_t i = 0; i < all_links.size(); ++i) {
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
for (auto & all_link : all_links) {
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode, enable TCP keepalive
all_links[i].sock.SetNonBlock(true);
all_links[i].sock.SetKeepAlive(true);
all_link.sock.SetNonBlock(true);
all_link.sock.SetKeepAlive(true);
if (rabit_enable_tcp_no_delay) {
setsockopt(all_links[i].sock, IPPROTO_TCP,
setsockopt(all_link.sock, IPPROTO_TCP,
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
}
if (tree_neighbors.count(all_links[i].rank) != 0) {
if (all_links[i].rank == parent_rank) {
if (tree_neighbors.count(all_link.rank) != 0) {
if (all_link.rank == parent_rank) {
parent_index = static_cast<int>(tree_links.plinks.size());
}
tree_links.plinks.push_back(&all_links[i]);
tree_links.plinks.push_back(&all_link);
}
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
if (all_link.rank == prev_rank) ring_prev = &all_link;
if (all_link.rank == next_rank) ring_next = &all_link;
}
Assert(parent_rank == -1 || parent_index != -1,
"cannot find parent in the link");
Assert(prev_rank == -1 || ring_prev != NULL,
Assert(prev_rank == -1 || ring_prev != nullptr,
"cannot find prev ring in the link");
Assert(next_rank == -1 || ring_next != NULL,
Assert(next_rank == -1 || ring_next != nullptr,
"cannot find next ring in the link");
return true;
} catch (const std::exception& e) {
@ -479,11 +480,11 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
size_t count,
ReduceFunction reducer) {
RefLinkVector &links = tree_links;
if (links.size() == 0 || count == 0) return kSuccess;
if (links.Size() == 0 || count == 0) return kSuccess;
// total size of message
const size_t total_size = type_nbytes * count;
// number of links
const int nlink = static_cast<int>(links.size());
const int nlink = static_cast<int>(links.Size());
// send recv buffer
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
// size of space that we already performs reduce in up pass
@ -581,8 +582,9 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
// if max reduce is less than total size, we reduce multiple times of
// eachreduce size
if (max_reduce < total_size)
if (max_reduce < total_size) {
max_reduce = max_reduce - max_reduce % eachreduce;
}
// perform reduce, can be at most two rounds
while (size_up_reduce < max_reduce) {
@ -677,11 +679,11 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
AllreduceBase::ReturnType
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
RefLinkVector &links = tree_links;
if (links.size() == 0 || total_size == 0) return kSuccess;
if (links.Size() == 0 || total_size == 0) return kSuccess;
utils::Check(root < world_size,
"Broadcast: root should be smaller than world size");
// number of links
const int nlink = static_cast<int>(links.size());
const int nlink = static_cast<int>(links.Size());
// size of space already read from data
size_t size_in = 0;
// input link, -2 means unknown yet, -1 means this is root

View File

@ -25,7 +25,7 @@
#endif // RABIT_CXXTESTDEFS_H
namespace MPI {
namespace MPI { // NOLINT
// MPI data type to be compatible with existing MPI interface
class Datatype {
public:
@ -41,12 +41,12 @@ class AllreduceBase : public IEngine {
// magic number to verify server
static const int kMagic = 0xff99;
// constant one byte out of band message to indicate error happening
AllreduceBase(void);
virtual ~AllreduceBase(void) {}
AllreduceBase();
virtual ~AllreduceBase() = default;
// initialize the manager
virtual bool Init(int argc, char* argv[]);
// shutdown the engine
virtual bool Shutdown(void);
virtual bool Shutdown();
/*!
* \brief set parameters to the engine
* \param name parameter name
@ -59,27 +59,27 @@ class AllreduceBase : public IEngine {
* the user who monitors the tracker
* \param msg message to be printed in the tracker
*/
virtual void TrackerPrint(const std::string &msg);
void TrackerPrint(const std::string &msg) override;
/*! \brief get rank of previous node in ring topology*/
virtual int GetRingPrevRank(void) const {
int GetRingPrevRank() const override {
return ring_prev->rank;
}
/*! \brief get rank */
virtual int GetRank(void) const {
int GetRank() const override {
return rank;
}
/*! \brief get world size (total number of nodes); reports 1 while world_size is still -1 */
virtual int GetWorldSize(void) const {
int GetWorldSize() const override {
if (world_size == -1) return 1;
return world_size;
}
/*! \brief whether is distributed or not */
virtual bool IsDistributed(void) const {
bool IsDistributed() const override {
return tracker_uri != "NULL";
}
/*! \brief get the uri of the current host */
virtual std::string GetHost(void) const {
std::string GetHost() const override {
return host_uri;
}
@ -99,13 +99,10 @@ class AllreduceBase : public IEngine {
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
size_t slice_end, size_t size_prev_slice,
const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override {
if (world_size == 1 || world_size == -1) return;
utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size,
slice_begin, slice_end, size_prev_slice) == kSuccess,
@ -126,16 +123,12 @@ class AllreduceBase : public IEngine {
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
ReduceFunction reducer, PreprocFunction prepare_fun = nullptr,
void *prepare_arg = nullptr, const char *_file = _FILE,
const int _line = _LINE,
const char *_caller = _CALLER) override {
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
if (world_size == 1 || world_size == -1) return;
utils::Assert(TryAllreduce(sendrecvbuf_,
type_nbytes, count, reducer) == kSuccess,
@ -150,8 +143,9 @@ class AllreduceBase : public IEngine {
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char* _file = _FILE, const int _line = _LINE, const char* _caller = _CALLER) {
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override {
if (world_size == 1 || world_size == -1) return;
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
"Broadcast failed");
@ -178,8 +172,8 @@ class AllreduceBase : public IEngine {
*
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL) {
int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = nullptr) override {
return 0;
}
/*!
@ -198,8 +192,8 @@ class AllreduceBase : public IEngine {
*
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) {
void CheckPoint(const Serializable *global_model,
const Serializable *local_model = nullptr) override {
version_number += 1;
}
/*!
@ -222,7 +216,7 @@ class AllreduceBase : public IEngine {
* is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber
*/
virtual void LazyCheckPoint(const Serializable *global_model) {
void LazyCheckPoint(const Serializable *global_model) override {
version_number += 1;
}
/*!
@ -230,7 +224,7 @@ class AllreduceBase : public IEngine {
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
virtual int VersionNumber(void) const {
int VersionNumber() const override {
return version_number;
}
/*!
@ -238,14 +232,14 @@ class AllreduceBase : public IEngine {
* call this function when IEngine throw an exception out,
* this function is only used for test purpose
*/
virtual void InitAfterException(void) {
void InitAfterException() override {
utils::Error("InitAfterException: not implemented");
}
/*!
* \brief report current status to the job tracker
* depending on the job tracker we are in
*/
inline void ReportStatus(void) const {
inline void ReportStatus() const {
if (hadoop_mode != 0) {
fprintf(stderr, "reporter:status:Rabit Phase[%03d] Operation %03d\n",
version_number, seq_counter);
@ -274,7 +268,7 @@ class AllreduceBase : public IEngine {
/*! \brief internal return type */
ReturnTypeEnum value;
// constructor
ReturnType() {}
ReturnType() = default;
ReturnType(ReturnTypeEnum value) : value(value) {} // NOLINT(*)
inline bool operator==(const ReturnTypeEnum &v) const {
return value == v;
@ -306,13 +300,11 @@ class AllreduceBase : public IEngine {
// size of data sent to the link
size_t size_write;
// pointer to buffer head
char *buffer_head;
char *buffer_head {nullptr};
// buffer size, in bytes
size_t buffer_size;
size_t buffer_size {0};
// constructor
LinkRecord(void)
: buffer_head(NULL), buffer_size(0) {
}
LinkRecord() = default;
// initialize buffer
inline void InitBuffer(size_t type_nbytes, size_t count,
size_t reduce_buffer_size) {
@ -328,7 +320,7 @@ class AllreduceBase : public IEngine {
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
}
// reset the recv and sent size
inline void ResetSize(void) {
inline void ResetSize() {
size_write = size_read = 0;
}
/*!
@ -340,7 +332,7 @@ class AllreduceBase : public IEngine {
* \return the type of reading
*/
inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) {
utils::Assert(buffer_head != NULL, "ReadToRingBuffer: buffer not allocated");
utils::Assert(buffer_head != nullptr, "ReadToRingBuffer: buffer not allocated");
utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check");
size_t ngap = size_read - protect_start;
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
@ -405,7 +397,7 @@ class AllreduceBase : public IEngine {
inline LinkRecord &operator[](size_t i) {
return *plinks[i];
}
inline size_t size(void) const {
inline size_t Size() const {
return plinks.size();
}
};
@ -413,7 +405,7 @@ class AllreduceBase : public IEngine {
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
utils::TCPSocket ConnectTracker(void) const;
utils::TCPSocket ConnectTracker() const;
/*!
* \brief connect to the tracker to fix the missing links
* this function is also used when the engine start up
@ -525,64 +517,64 @@ class AllreduceBase : public IEngine {
//---- data structure related to model ----
// call sequence counter, records how many calls we made so far
// from last call to CheckPoint, LoadCheckPoint
int seq_counter;
int seq_counter; // NOLINT
// version number of model
int version_number;
int version_number; // NOLINT
// whether the job is running in hadoop
bool hadoop_mode;
bool hadoop_mode; // NOLINT
//---- local data related to link ----
// index of parent link, can be -1, meaning this is root of the tree
int parent_index;
int parent_index; // NOLINT
// rank of parent node, can be -1
int parent_rank;
int parent_rank; // NOLINT
// sockets of all links this connects to
std::vector<LinkRecord> all_links;
std::vector<LinkRecord> all_links; // NOLINT
// used to record the link where things go wrong
LinkRecord *err_link;
LinkRecord *err_link; // NOLINT
// all the links in the reduction tree connection
RefLinkVector tree_links;
RefLinkVector tree_links; // NOLINT
// pointer to links in the ring
LinkRecord *ring_prev, *ring_next;
LinkRecord *ring_prev, *ring_next; // NOLINT
//----- meta information-----
// list of environment variables that are of possible interest
std::vector<std::string> env_vars;
std::vector<std::string> env_vars; // NOLINT
// unique identifier of the possible job this process is doing
// used to assign ranks, optional, default to NULL
std::string task_id;
std::string task_id; // NOLINT
// uri of current host, to be set by Init
std::string host_uri;
std::string host_uri; // NOLINT
// uri of tracker
std::string tracker_uri;
std::string tracker_uri; // NOLINT
// role in dmlc jobs
std::string dmlc_role;
std::string dmlc_role; // NOLINT
// port of tracker address
int tracker_port;
int tracker_port; // NOLINT
// port of slave process
int slave_port, nport_trial;
int slave_port, nport_trial; // NOLINT
// reduce buffer size
size_t reduce_buffer_size;
size_t reduce_buffer_size; // NOLINT
// reduction method
int reduce_method;
int reduce_method; // NOLINT
// minimum count of cells to use ring based method
size_t reduce_ring_mincount;
size_t reduce_ring_mincount; // NOLINT
// minimum block size per tree reduce
size_t tree_reduce_minsize;
size_t tree_reduce_minsize; // NOLINT
// current rank
int rank;
int rank; // NOLINT
// world size
int world_size;
int world_size; // NOLINT
// connect retry time
int connect_retry;
int connect_retry; // NOLINT
// enable bootstrap cache 0 false 1 true
bool rabit_bootstrap_cache = false;
bool rabit_bootstrap_cache = false; // NOLINT
// enable detailed logging
bool rabit_debug = false;
bool rabit_debug = false; // NOLINT
// by default, if rabit worker not recover in half an hour exit
int timeout_sec = 1800;
int timeout_sec = 1800; // NOLINT
// flag to enable rabit_timeout
bool rabit_timeout = false;
bool rabit_timeout = false; // NOLINT
// Enable TCP_NODELAY on the link sockets (set via rabit_enable_tcp_no_delay)
bool rabit_enable_tcp_no_delay = false;
bool rabit_enable_tcp_no_delay = false; // NOLINT
};
} // namespace engine
} // namespace rabit

View File

@ -20,74 +20,65 @@ namespace engine {
class AllreduceMock : public AllreduceRobust {
public:
// constructor
AllreduceMock(void) {
num_trial = 0;
force_local = 0;
report_stats = 0;
tsum_allreduce = 0.0;
tsum_allgather = 0.0;
AllreduceMock() {
num_trial_ = 0;
force_local_ = 0;
report_stats_ = 0;
tsum_allreduce_ = 0.0;
tsum_allgather_ = 0.0;
}
// destructor
virtual ~AllreduceMock(void) {}
virtual void SetParam(const char *name, const char *val) {
~AllreduceMock() override = default;
void SetParam(const char *name, const char *val) override {
AllreduceRobust::SetParam(name, val);
// additional parameters
if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val);
if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial = atoi(val);
if (!strcmp(name, "report_stats")) report_stats = atoi(val);
if (!strcmp(name, "force_local")) force_local = atoi(val);
if (!strcmp(name, "rabit_num_trial")) num_trial_ = atoi(val);
if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial_ = atoi(val);
if (!strcmp(name, "report_stats")) report_stats_ = atoi(val);
if (!strcmp(name, "force_local")) force_local_ = atoi(val);
if (!strcmp(name, "mock")) {
MockKey k;
utils::Check(sscanf(val, "%d,%d,%d,%d",
&k.rank, &k.version, &k.seqno, &k.ntrial) == 4,
"invalid mock parameter");
mock_map[k] = 1;
mock_map_[k] = 1;
}
}
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
ReduceFunction reducer, PreprocFunction prepare_fun,
void *prepare_arg, const char *_file = _FILE,
const int _line = _LINE,
const char *_caller = _CALLER) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "AllReduce");
double tstart = utils::GetTime();
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
count, reducer, prepare_fun, prepare_arg,
_file, _line, _caller);
tsum_allreduce += utils::GetTime() - tstart;
tsum_allreduce_ += utils::GetTime() - tstart;
}
virtual void Allgather(void *sendrecvbuf,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather");
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin,
size_t slice_end, size_t size_prev_slice,
const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Allgather");
double tstart = utils::GetTime();
AllreduceRobust::Allgather(sendrecvbuf, total_size,
slice_begin, slice_end,
size_prev_slice, _file, _line, _caller);
tsum_allgather += utils::GetTime() - tstart;
tsum_allgather_ += utils::GetTime() - tstart;
}
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Broadcast");
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller);
}
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model) {
tsum_allreduce = 0.0;
tsum_allgather = 0.0;
time_checkpoint = utils::GetTime();
if (force_local == 0) {
int LoadCheckPoint(Serializable *global_model,
Serializable *local_model) override {
tsum_allreduce_ = 0.0;
tsum_allgather_ = 0.0;
time_checkpoint_ = utils::GetTime();
if (force_local_ == 0) {
return AllreduceRobust::LoadCheckPoint(global_model, local_model);
} else {
DummySerializer dum;
@ -95,56 +86,54 @@ class AllreduceMock : public AllreduceRobust {
return AllreduceRobust::LoadCheckPoint(&dum, &com);
}
}
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
void CheckPoint(const Serializable *global_model,
const Serializable *local_model) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "CheckPoint");
double tstart = utils::GetTime();
double tbet_chkpt = tstart - time_checkpoint;
if (force_local == 0) {
double tbet_chkpt = tstart - time_checkpoint_;
if (force_local_ == 0) {
AllreduceRobust::CheckPoint(global_model, local_model);
} else {
DummySerializer dum;
ComboSerializer com(global_model, local_model);
AllreduceRobust::CheckPoint(&dum, &com);
}
time_checkpoint = utils::GetTime();
time_checkpoint_ = utils::GetTime();
double tcost = utils::GetTime() - tstart;
if (report_stats != 0 && rank == 0) {
if (report_stats_ != 0 && rank == 0) {
std::stringstream ss;
ss << "[v" << version_number << "] global_size=" << global_checkpoint.length()
<< ",local_size=" << (local_chkpt[0].length() + local_chkpt[1].length())
ss << "[v" << version_number << "] global_size=" << global_checkpoint_.length()
<< ",local_size=" << (local_chkpt_[0].length() + local_chkpt_[1].length())
<< ",check_tcost="<< tcost <<" sec"
<< ",allreduce_tcost=" << tsum_allreduce << " sec"
<< ",allgather_tcost=" << tsum_allgather << " sec"
<< ",allreduce_tcost=" << tsum_allreduce_ << " sec"
<< ",allgather_tcost=" << tsum_allgather_ << " sec"
<< ",between_chpt=" << tbet_chkpt << "sec\n";
this->TrackerPrint(ss.str());
}
tsum_allreduce = 0.0;
tsum_allgather = 0.0;
tsum_allreduce_ = 0.0;
tsum_allgather_ = 0.0;
}
virtual void LazyCheckPoint(const Serializable *global_model) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
void LazyCheckPoint(const Serializable *global_model) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "LazyCheckPoint");
AllreduceRobust::LazyCheckPoint(global_model);
}
protected:
// force checkpoint to local
int force_local;
int force_local_;
// whether report statistics
int report_stats;
int report_stats_;
// sum of allreduce
double tsum_allreduce;
double tsum_allreduce_;
// sum of allgather
double tsum_allgather;
double time_checkpoint;
double tsum_allgather_;
double time_checkpoint_;
private:
struct DummySerializer : public Serializable {
virtual void Load(Stream *fi) {
}
virtual void Save(Stream *fo) const {
}
void Load(Stream *fi) override {}
void Save(Stream *fo) const override {}
};
struct ComboSerializer : public Serializable {
Serializable *lhs;
@ -155,15 +144,15 @@ class AllreduceMock : public AllreduceRobust {
: lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
}
ComboSerializer(const Serializable *lhs, const Serializable *rhs)
: lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
: lhs(nullptr), rhs(nullptr), c_lhs(lhs), c_rhs(rhs) {
}
virtual void Load(Stream *fi) {
if (lhs != NULL) lhs->Load(fi);
if (rhs != NULL) rhs->Load(fi);
void Load(Stream *fi) override {
if (lhs != nullptr) lhs->Load(fi);
if (rhs != nullptr) rhs->Load(fi);
}
virtual void Save(Stream *fo) const {
if (c_lhs != NULL) c_lhs->Save(fo);
if (c_rhs != NULL) c_rhs->Save(fo);
void Save(Stream *fo) const override {
if (c_lhs != nullptr) c_lhs->Save(fo);
if (c_rhs != nullptr) c_rhs->Save(fo);
}
};
// key to identify the mock stage
@ -172,7 +161,7 @@ class AllreduceMock : public AllreduceRobust {
int version;
int seqno;
int ntrial;
MockKey(void) {}
MockKey() = default;
MockKey(int rank, int version, int seqno, int ntrial)
: rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}
inline bool operator==(const MockKey &b) const {
@ -189,15 +178,15 @@ class AllreduceMock : public AllreduceRobust {
}
};
// number of failure trials
int num_trial;
int num_trial_;
// record all mock actions
std::map<MockKey, int> mock_map;
std::map<MockKey, int> mock_map_;
// used to generate all kinds of exceptions
inline void Verify(const MockKey &key, const char *name) {
  // nothing to do unless this stage was registered in mock_map_
  if (mock_map_.find(key) == mock_map_.end()) return;
  // count the hit before reporting it
  ++num_trial_;
  // data processing frameworks run on a shared process
  error_("[%d]@@@Hit Mock Error:%s ", rank, name);
}
};

View File

@ -40,9 +40,9 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
const std::vector<EdgeType> &edge_in,
size_t out_index)) {
RefLinkVector &links = tree_links;
if (links.size() == 0) return kSuccess;
if (links.Size() == 0) return kSuccess;
// number of links
const int nlink = static_cast<int>(links.size());
const int nlink = static_cast<int>(links.Size());
// initialize the pointers
for (int i = 0; i < nlink; ++i) {
links[i].ResetSize();

File diff suppressed because it is too large Load Diff

View File

@ -22,18 +22,18 @@ namespace engine {
/*! \brief implementation of fault tolerant all reduce engine */
class AllreduceRobust : public AllreduceBase {
public:
AllreduceRobust(void);
virtual ~AllreduceRobust(void) {}
AllreduceRobust();
~AllreduceRobust() override = default;
// initialize the manager
virtual bool Init(int argc, char* argv[]);
bool Init(int argc, char* argv[]) override;
/*! \brief shutdown the engine */
virtual bool Shutdown(void);
bool Shutdown() override;
/*!
* \brief set parameters to the engine
* \param name parameter name
* \param val parameter value
*/
virtual void SetParam(const char *name, const char *val);
void SetParam(const char *name, const char *val) override;
/*!
* \brief perform immutable local bootstrap cache insertion
* \param key unique cache key
@ -67,13 +67,10 @@ class AllreduceRobust : public AllreduceBase {
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
size_t slice_end, size_t size_prev_slice,
const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override;
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
@ -90,15 +87,11 @@ class AllreduceRobust : public AllreduceBase {
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
ReduceFunction reducer, PreprocFunction prepare_fun = nullptr,
void *prepare_arg = nullptr, const char *_file = _FILE,
const int _line = _LINE,
const char *_caller = _CALLER) override;
/*!
* \brief broadcast data from root to all nodes
* \param sendrecvbuf_ buffer for both sending and recving data
@ -108,10 +101,9 @@ class AllreduceRobust : public AllreduceBase {
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override;
/*!
* \brief load latest check point
* \param global_model pointer to the globally shared model/state
@ -134,8 +126,8 @@ class AllreduceRobust : public AllreduceBase {
*
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL);
int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = nullptr) override;
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
@ -152,9 +144,9 @@ class AllreduceRobust : public AllreduceBase {
*
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) {
this->CheckPoint_(global_model, local_model, false);
void CheckPoint(const Serializable *global_model,
const Serializable *local_model = nullptr) override {
this->CheckPointImpl(global_model, local_model, false);
}
/*!
* \brief This function can be used to replace CheckPoint for global_model only,
@ -176,18 +168,20 @@ class AllreduceRobust : public AllreduceBase {
* is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber
*/
virtual void LazyCheckPoint(const Serializable *global_model) {
this->CheckPoint_(global_model, NULL, true);
void LazyCheckPoint(const Serializable *global_model) override {
this->CheckPointImpl(global_model, nullptr, true);
}
/*!
* \brief explicitly re-init everything before calling LoadCheckPoint
* call this function when IEngine throw an exception out,
* this function is only used for test purpose
*/
virtual void InitAfterException(void) {
void InitAfterException() override {
// simple way, shutdown all links
for (size_t i = 0; i < all_links.size(); ++i) {
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
for (auto& link : all_links) {
if (link.sock.BadSocket()) {
link.sock.Close();
}
}
ReConnectLinks("recover");
}
@ -245,69 +239,69 @@ class AllreduceRobust : public AllreduceBase {
// there are nodes request load cache
static const int kLoadBootstrapCache = 16;
// constructor
ActionSummary(void) {}
ActionSummary() = default;
// constructor of action
explicit ActionSummary(int seqno_flag, int cache_flag = 0,
u_int32_t minseqno = kSpecialOp, u_int32_t maxseqno = kSpecialOp) {
seqcode = (minseqno << 5) | seqno_flag;
maxseqcode = (maxseqno << 5) | cache_flag;
seqcode_ = (minseqno << 5) | seqno_flag;
maxseqcode_ = (maxseqno << 5) | cache_flag;
}
// minimum number of all operations by default
// maximum number of all cache operations otherwise
inline u_int32_t seqno(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
inline u_int32_t Seqno(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return code >> 5;
}
// whether the operation set contains a load_check
inline bool load_check(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
inline bool LoadCheck(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return (code & kLoadCheck) != 0;
}
// whether the operation set contains a load_cache
inline bool load_cache(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
inline bool LoadCache(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return (code & kLoadBootstrapCache) != 0;
}
// whether the operation set contains a check point
inline bool check_point(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
inline bool CheckPoint(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return (code & kCheckPoint) != 0;
}
// whether the operation set contains a check ack
inline bool check_ack(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
inline bool CheckAck(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return (code & kCheckAck) != 0;
}
// whether the operation set contains different sequence number
inline bool diff_seq() const {
return (seqcode & kDiffSeq) != 0;
inline bool DiffSeq() const {
return (seqcode_ & kDiffSeq) != 0;
}
// returns the operation flag of the result
inline int flag(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
inline int Flag(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return code & 31;
}
// print flags in user friendly way
inline void print_flags(int rank, std::string prefix ) {
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n",
rank, prefix.c_str(),
seqno(), check_point(), check_ack(), load_cache(),
diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache));
inline void PrintFlags(int rank, std::string prefix ) {
// Seqno() returns u_int32_t; it is cast to unsigned and printed with %u so
// the argument type matches the conversion specifier (%lu expects
// unsigned long).
utils::HandleLogInfo("[%d] %s - |%u|%d|%d|%d|%d| - |%u|%d|\n", rank,
prefix.c_str(), static_cast<unsigned>(Seqno()), CheckPoint(),
CheckAck(), LoadCache(), DiffSeq(),
static_cast<unsigned>(Seqno(SeqType::kCache)),
LoadCache(SeqType::kCache));
}
// reducer for Allreduce, get the result ActionSummary from all nodes
inline static void Reducer(const void *src_, void *dst_,
int len, const MPI::Datatype &dtype) {
const ActionSummary *src = (const ActionSummary*)src_;
const ActionSummary *src = static_cast<const ActionSummary*>(src_);
ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
for (int i = 0; i < len; ++i) {
u_int32_t min_seqno = std::min(src[i].seqno(), dst[i].seqno());
u_int32_t max_seqno = std::max(src[i].seqno(SeqType::kCache),
dst[i].seqno(SeqType::kCache));
int action_flag = src[i].flag() | dst[i].flag();
u_int32_t min_seqno = std::min(src[i].Seqno(), dst[i].Seqno());
u_int32_t max_seqno = std::max(src[i].Seqno(SeqType::kCache),
dst[i].Seqno(SeqType::kCache));
int action_flag = src[i].Flag() | dst[i].Flag();
// if any node is not requester set to 0 otherwise 1
int role_flag = src[i].flag(SeqType::kCache) & dst[i].flag(SeqType::kCache);
int role_flag = src[i].Flag(SeqType::kCache) & dst[i].Flag(SeqType::kCache);
// if seqno is different in src and destination
int seq_diff_flag = src[i].seqno() != dst[i].seqno() ? kDiffSeq : 0;
int seq_diff_flag = src[i].Seqno() != dst[i].Seqno() ? kDiffSeq : 0;
// apply or to both seq diff flag as well as cache seq diff flag
dst[i] = ActionSummary(action_flag | seq_diff_flag,
role_flag, min_seqno, max_seqno);
@ -316,19 +310,19 @@ class AllreduceRobust : public AllreduceBase {
private:
// internal sequence code min of rabit seqno
u_int32_t seqcode;
u_int32_t seqcode_;
// internal sequence code max of cache seqno
u_int32_t maxseqcode;
u_int32_t maxseqcode_;
};
/*! \brief data structure to remember result of Bcast and Allreduce calls*/
class ResultBuffer{
public:
// constructor
ResultBuffer(void) {
ResultBuffer() {
this->Clear();
}
// clear the existing record
inline void Clear(void) {
inline void Clear() {
seqno_.clear(); size_.clear();
rptr_.clear(); rptr_.push_back(0);
data_.clear();
@ -358,12 +352,12 @@ class AllreduceRobust : public AllreduceBase {
// Look up the stored result recorded under sequence id `seqid`.
// The lookup uses std::lower_bound over seqno_ (presumably kept in ascending
// order by the code that appends entries -- not enforced here).
// Returns nullptr when no entry has exactly this id; *p_size is not written
// in that case. Otherwise stores size_[idx] into *p_size and returns a pointer
// into data_ at offset rptr_[idx].
inline void* Query(int seqid, size_t *p_size) {
size_t idx = std::lower_bound(seqno_.begin(),
seqno_.end(), seqid) - seqno_.begin();
if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL;
if (idx == seqno_.size() || seqno_[idx] != seqid) return nullptr;
*p_size = size_[idx];
return BeginPtr(data_) + rptr_[idx];
}
// drop last stored result
inline void DropLast(void) {
inline void DropLast() {
utils::Assert(seqno_.size() != 0, "there is nothing to be dropped");
seqno_.pop_back();
rptr_.pop_back();
@ -371,7 +365,7 @@ class AllreduceRobust : public AllreduceBase {
data_.resize(rptr_.back());
}
// the sequence number of last stored result
inline int LastSeqNo(void) const {
inline int LastSeqNo() const {
if (seqno_.size() == 0) return -1;
return seqno_.back();
}
@ -407,9 +401,8 @@ class AllreduceRobust : public AllreduceBase {
*
* \sa CheckPoint, LazyCheckPoint
*/
void CheckPoint_(const Serializable *global_model,
const Serializable *local_model,
bool lazy_checkpt);
void CheckPointImpl(const Serializable *global_model,
const Serializable *local_model, bool lazy_checkpt);
/*!
* \brief reset all the existing links by sending Out-of-Band message marker
* after this function finishes, all the messages received and sent
@ -423,7 +416,7 @@ class AllreduceRobust : public AllreduceBase {
* when kSockError is returned, it simply means there are bad sockets in the links,
* and some link recovery procedure is needed
*/
ReturnType TryResetLinks(void);
ReturnType TryResetLinks();
/*!
* \brief if err_type indicates an error
* recover links according to the error type reported
@ -619,29 +612,29 @@ o * the input state must exactly one saved state(local state of current node)
size_t out_index));
//---- recovery data structure ----
// the round of result buffer, used to mode the result
int result_buffer_round;
int result_buffer_round_;
// result buffer of all reduce
ResultBuffer resbuf;
ResultBuffer resbuf_;
// current cached allreduce/broadcast sequence number
int cur_cache_seq;
int cur_cache_seq_;
// result buffer of cached all reduce
ResultBuffer cachebuf;
ResultBuffer cachebuf_;
// key of each cache entry
ResultBuffer lookupbuf;
ResultBuffer lookupbuf_;
// last check point global model
std::string global_checkpoint;
std::string global_checkpoint_;
// lazy checkpoint of global model
const Serializable *global_lazycheck;
const Serializable *global_lazycheck_;
// number of replica for local state/model
int num_local_replica;
int num_local_replica_;
// number of default local replica
int default_local_replica;
int default_local_replica_;
// flag to decide whether local model is used, -1: unknown, 0: no, 1:yes
int use_local_model;
int use_local_model_;
// number of replica for global state/model
int num_global_replica;
int num_global_replica_;
// number of times recovery happens
int recover_counter;
int recover_counter_;
// --- recovery data structure for local checkpoint
// there are two versions of the data structure,
// at one time one version is valid and another is used as temp memory
@ -649,21 +642,21 @@ o * the input state must exactly one saved state(local state of current node)
// local model is stored in CSR format(like a sparse matrices)
// local_model[rptr[0]:rptr[1]] stores the model of current node
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops
std::vector<size_t> local_rptr[2];
std::vector<size_t> local_rptr_[2];
// storage for local model replicas
std::string local_chkpt[2];
std::string local_chkpt_[2];
// version of local checkpoint can be 1 or 0
int local_chkpt_version;
int local_chkpt_version_;
// if checkpoint was loaded, used to distinguish results of bootstrap cache from seqno cache
bool checkpoint_loaded;
bool checkpoint_loaded_;
// sidecar executing timeout task
std::future<bool> rabit_timeout_task;
std::future<bool> rabit_timeout_task_;
// flag to shutdown rabit_timeout_task before timeout
std::atomic<bool> shutdown_timeout{false};
std::atomic<bool> shutdown_timeout_{false};
// error handler
void (* _error)(const char *fmt, ...) = utils::Error;
void (* error_)(const char *fmt, ...) = utils::Error;
// assert handler
void (* _assert)(bool exp, const char *fmt, ...) = utils::Assert;
void (* assert_)(bool exp, const char *fmt, ...) = utils::Assert;
};
} // namespace engine
} // namespace rabit

View File

@ -33,12 +33,12 @@ struct FHelper<op::BitOR, DType> {
};
template<typename OP>
void Allreduce_(void *sendrecvbuf_,
void Allreduce(void *sendrecvbuf_,
size_t count,
engine::mpi::DataType enum_dtype,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
using namespace engine::mpi;
using namespace engine::mpi; // NOLINT
switch (enum_dtype) {
case kChar:
rabit::Allreduce<OP>
@ -89,28 +89,28 @@ void Allreduce(void *sendrecvbuf,
engine::mpi::OpType enum_op,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
using namespace engine::mpi;
using namespace engine::mpi; // NOLINT
switch (enum_op) {
case kMax:
Allreduce_<op::Max>
Allreduce<op::Max>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
case kMin:
Allreduce_<op::Min>
Allreduce<op::Min>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
case kSum:
Allreduce_<op::Sum>
Allreduce<op::Sum>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
case kBitwiseOR:
Allreduce_<op::BitOR>
Allreduce<op::BitOR>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
@ -124,7 +124,7 @@ void Allgather(void *sendrecvbuf_,
size_t size_node_slice,
size_t size_prev_slice,
int enum_dtype) {
using namespace engine::mpi;
using namespace engine::mpi; // NOLINT
size_t type_size = 0;
switch (enum_dtype) {
case kChar:
@ -184,7 +184,7 @@ struct ReadWrapper : public Serializable {
std::string *p_str;
explicit ReadWrapper(std::string *p_str)
: p_str(p_str) {}
virtual void Load(Stream *fi) {
void Load(Stream *fi) override {
uint64_t sz;
utils::Assert(fi->Read(&sz, sizeof(sz)) != 0,
"Read pickle string");
@ -194,7 +194,7 @@ struct ReadWrapper : public Serializable {
"Read pickle string");
}
}
virtual void Save(Stream *fo) const {
void Save(Stream *fo) const override {
utils::Error("not implemented");
}
};
@ -206,10 +206,10 @@ struct WriteWrapper : public Serializable {
size_t length)
: data(data), length(length) {
}
virtual void Load(Stream *fi) {
void Load(Stream *fi) override {
utils::Error("not implemented");
}
virtual void Save(Stream *fo) const {
void Save(Stream *fo) const override {
// The length prefix is a full 64-bit value (ReadWrapper::Load reads a
// uint64_t back). Casting through uint16_t would truncate the prefix for
// payloads of 65536 bytes or more, while the payload below is still written
// with the untruncated length.
uint64_t sz = static_cast<uint64_t>(length);
fo->Write(&sz, sizeof(sz));
fo->Write(data, length * sizeof(char));
@ -298,8 +298,8 @@ RABIT_DLL int RabitLoadCheckPoint(char **out_global_model,
ReadWrapper sl(&local_buffer);
int version;
if (out_local_model == NULL) {
version = rabit::LoadCheckPoint(&sg, NULL);
if (out_local_model == nullptr) {
version = rabit::LoadCheckPoint(&sg, nullptr);
*out_global_model = BeginPtr(global_buffer);
*out_global_len = static_cast<rbt_ulong>(global_buffer.length());
} else {
@ -317,8 +317,8 @@ RABIT_DLL void RabitCheckPoint(const char *global_model, rbt_ulong global_len,
using namespace rabit::c_api; // NOLINT(*)
WriteWrapper sg(global_model, global_len);
WriteWrapper sl(local_model, local_len);
if (local_model == NULL) {
rabit::CheckPoint(&sg, NULL);
if (local_model == nullptr) {
rabit::CheckPoint(&sg, nullptr);
} else {
rabit::CheckPoint(&sg, &sl);
}

View File

@ -7,18 +7,19 @@
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#include <rabit/base.h>
#include <dmlc/thread_local.h>
#include <memory>
#include "rabit/internal/engine.h"
#include "allreduce_base.h"
#include "allreduce_robust.h"
#include "rabit/internal/thread_local.h"
namespace rabit {
namespace engine {
// singleton sync manager
#ifndef RABIT_USE_BASE
#ifndef RABIT_USE_MOCK
typedef AllreduceRobust Manager;
using Manager = AllreduceRobust;
#else
typedef AllreduceMock Manager;
#endif // RABIT_USE_MOCK
@ -31,13 +32,13 @@ struct ThreadLocalEntry {
/*! \brief stores the current engine */
std::unique_ptr<Manager> engine;
/*! \brief whether init has been called */
bool initialized;
bool initialized{false};
/*! \brief constructor */
ThreadLocalEntry() : initialized(false) {}
ThreadLocalEntry() = default;
};
// define the threadlocal store.
typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal;
using EngineThreadLocal = dmlc::ThreadLocalStore<ThreadLocalEntry>;
/*! \brief initialize the synchronization module */
bool Init(int argc, char *argv[]) {
@ -95,7 +96,7 @@ void Allgather(void *sendrecvbuf_, size_t total_size,
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
void Allreduce_(void *sendrecvbuf, // NOLINT
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
@ -111,18 +112,15 @@ void Allreduce_(void *sendrecvbuf,
}
// code for reduce handle
ReduceHandle::ReduceHandle(void)
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
}
ReduceHandle::~ReduceHandle(void) {}
ReduceHandle::ReduceHandle() = default;
ReduceHandle::~ReduceHandle() = default;
// Reports the `type_size` field of `dtype`, narrowed to int.
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
  const auto nbytes = dtype.type_size;
  return static_cast<int>(nbytes);
}
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice");
utils::Assert(redfunc_ == nullptr, "cannot initialize reduce handle twice");
redfunc_ = redfunc;
}
@ -133,7 +131,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf,
const char* _file,
const int _line,
const char* _caller) {
utils::Assert(redfunc_ != NULL, "must intialize handle to call AllReduce");
utils::Assert(redfunc_ != nullptr, "must intialize handle to call AllReduce");
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
redfunc_, prepare_fun, prepare_arg,
_file, _line, _caller);

View File

@ -16,78 +16,68 @@ namespace engine {
/*! \brief EmptyEngine */
class EmptyEngine : public IEngine {
public:
EmptyEngine(void) {
version_number = 0;
EmptyEngine() {
version_number_ = 0;
}
virtual void Allgather(void *sendrecvbuf_,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file,
const int _line,
const char* _caller) {
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
size_t slice_end, size_t size_prev_slice, const char *_file,
const int _line, const char *_caller) override {
utils::Error("EmptyEngine:: Allgather is not supported");
}
virtual int GetRingPrevRank(void) const {
int GetRingPrevRank() const override {
utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
return -1;
}
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
ReduceFunction reducer, PreprocFunction prepare_fun,
void *prepare_arg, const char *_file, const int _line,
const char *_caller) override {
utils::Error("EmptyEngine:: Allreduce is not supported,"\
"use Allreduce_ instead");
}
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root,
const char* _file, const int _line, const char* _caller) {
void Broadcast(void *sendrecvbuf_, size_t size, int root,
const char* _file, const int _line, const char* _caller) override {
}
virtual void InitAfterException(void) {
void InitAfterException() override {
utils::Error("EmptyEngine is not fault tolerant");
}
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL) {
int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = nullptr) override {
return 0;
}
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) {
version_number += 1;
void CheckPoint(const Serializable *global_model,
const Serializable *local_model = nullptr) override {
version_number_ += 1;
}
virtual void LazyCheckPoint(const Serializable *global_model) {
version_number += 1;
void LazyCheckPoint(const Serializable *global_model) override {
version_number_ += 1;
}
virtual int VersionNumber(void) const {
return version_number;
int VersionNumber() const override {
return version_number_;
}
/*! \brief get rank of current node */
virtual int GetRank(void) const {
int GetRank() const override {
return 0;
}
/*! \brief get total number of */
virtual int GetWorldSize(void) const {
int GetWorldSize() const override {
return 1;
}
/*! \brief whether it is distributed */
virtual bool IsDistributed(void) const {
bool IsDistributed() const override {
return false;
}
/*! \brief get the host name of current node */
virtual std::string GetHost(void) const {
std::string GetHost() const override {
return std::string("");
}
virtual void TrackerPrint(const std::string &msg) {
void TrackerPrint(const std::string &msg) override {
// simply print information into the tracker
utils::Printf("%s", msg.c_str());
}
private:
int version_number;
int version_number_;
};
// singleton sync manager
@ -98,12 +88,12 @@ bool Init(int argc, char *argv[]) {
return true;
}
/*! \brief finalize synchronization module */
bool Finalize(void) {
bool Finalize() {
return true;
}
/*! \brief singleton method to get engine */
IEngine *GetEngine(void) {
IEngine *GetEngine() {
return &manager;
}
// perform in-place allreduce, on sendrecvbuf
@ -118,13 +108,12 @@ void Allreduce_(void *sendrecvbuf,
const char* _file,
const int _line,
const char* _caller) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
}
// code for reduce handle
ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
}
ReduceHandle::~ReduceHandle(void) {}
ReduceHandle::ReduceHandle() = default;
ReduceHandle::~ReduceHandle() = default;
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return 0;
@ -137,7 +126,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf,
const char* _file,
const int _line,
const char* _caller) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
}
} // namespace engine
} // namespace rabit

View File

@ -12,4 +12,3 @@
#include <rabit/base.h>
#include "allreduce_mock.h"
#include "engine.cc"