Correct style warnings from clang-tidy for rabit. (#6095)
This commit is contained in:
parent
da61d9460b
commit
b0001a6e29
@ -19,7 +19,7 @@
|
|||||||
#define _CALLER "N/A"
|
#define _CALLER "N/A"
|
||||||
#endif // (defined(__GNUC__) && !defined(__clang__))
|
#endif // (defined(__GNUC__) && !defined(__clang__))
|
||||||
|
|
||||||
namespace MPI {
|
namespace MPI { // NOLINT
|
||||||
/*! \brief MPI data type just to be compatible with MPI reduce function*/
|
/*! \brief MPI data type just to be compatible with MPI reduce function*/
|
||||||
class Datatype;
|
class Datatype;
|
||||||
}
|
}
|
||||||
@ -36,7 +36,7 @@ class IEngine {
|
|||||||
* used to prepare the data used by AllReduce
|
* used to prepare the data used by AllReduce
|
||||||
* \param arg additional possible argument used to invoke the preprocessor
|
* \param arg additional possible argument used to invoke the preprocessor
|
||||||
*/
|
*/
|
||||||
typedef void (PreprocFunction) (void *arg);
|
typedef void (PreprocFunction) (void *arg); // NOLINT
|
||||||
/*!
|
/*!
|
||||||
* \brief reduce function, the same form of MPI reduce function is used,
|
* \brief reduce function, the same form of MPI reduce function is used,
|
||||||
* to be compatible with MPI interface
|
* to be compatible with MPI interface
|
||||||
@ -48,11 +48,11 @@ class IEngine {
|
|||||||
* the definition of the reduce function should be type aware
|
* the definition of the reduce function should be type aware
|
||||||
* \param dtype the data type object, to be compatible with MPI reduce
|
* \param dtype the data type object, to be compatible with MPI reduce
|
||||||
*/
|
*/
|
||||||
typedef void (ReduceFunction) (const void *src,
|
typedef void (ReduceFunction) (const void *src, // NOLINT
|
||||||
void *dst, int count,
|
void *dst, int count,
|
||||||
const MPI::Datatype &dtype);
|
const MPI::Datatype &dtype);
|
||||||
/*! \brief virtual destructor */
|
/*! \brief virtual destructor */
|
||||||
virtual ~IEngine() {}
|
~IEngine() = default;
|
||||||
/*!
|
/*!
|
||||||
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||||
* the data provided by current node k is [slice_begin, slice_end),
|
* the data provided by current node k is [slice_begin, slice_end),
|
||||||
@ -96,8 +96,8 @@ class IEngine {
|
|||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer,
|
ReduceFunction reducer,
|
||||||
PreprocFunction prepare_fun = NULL,
|
PreprocFunction prepare_fun = nullptr,
|
||||||
void *prepare_arg = NULL,
|
void *prepare_arg = nullptr,
|
||||||
const char* _file = _FILE,
|
const char* _file = _FILE,
|
||||||
const int _line = _LINE,
|
const int _line = _LINE,
|
||||||
const char* _caller = _CALLER) = 0;
|
const char* _caller = _CALLER) = 0;
|
||||||
@ -119,7 +119,7 @@ class IEngine {
|
|||||||
* call this function when IEngine throws an exception,
|
* call this function when IEngine throws an exception,
|
||||||
* this function should only be used for test purposes
|
* this function should only be used for test purposes
|
||||||
*/
|
*/
|
||||||
virtual void InitAfterException(void) = 0;
|
virtual void InitAfterException() = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief loads the latest check point
|
* \brief loads the latest check point
|
||||||
* \param global_model pointer to the globally shared model/state
|
* \param global_model pointer to the globally shared model/state
|
||||||
@ -143,7 +143,7 @@ class IEngine {
|
|||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual int LoadCheckPoint(Serializable *global_model,
|
virtual int LoadCheckPoint(Serializable *global_model,
|
||||||
Serializable *local_model = NULL) = 0;
|
Serializable *local_model = nullptr) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief checkpoints the model, meaning a stage of execution was finished
|
* \brief checkpoints the model, meaning a stage of execution was finished
|
||||||
* every time we call check point, a version number increases by ones
|
* every time we call check point, a version number increases by ones
|
||||||
@ -161,7 +161,7 @@ class IEngine {
|
|||||||
* \sa LoadCheckPoint, VersionNumber
|
* \sa LoadCheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual void CheckPoint(const Serializable *global_model,
|
virtual void CheckPoint(const Serializable *global_model,
|
||||||
const Serializable *local_model = NULL) = 0;
|
const Serializable *local_model = nullptr) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief This function can be used to replace CheckPoint for global_model only,
|
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||||
* when certain condition is met (see detailed explanation).
|
* when certain condition is met (see detailed explanation).
|
||||||
@ -188,17 +188,17 @@ class IEngine {
|
|||||||
* which means how many calls to CheckPoint we made so far
|
* which means how many calls to CheckPoint we made so far
|
||||||
* \sa LoadCheckPoint, CheckPoint
|
* \sa LoadCheckPoint, CheckPoint
|
||||||
*/
|
*/
|
||||||
virtual int VersionNumber(void) const = 0;
|
virtual int VersionNumber() const = 0;
|
||||||
/*! \brief gets rank of previous node in ring topology */
|
/*! \brief gets rank of previous node in ring topology */
|
||||||
virtual int GetRingPrevRank(void) const = 0;
|
virtual int GetRingPrevRank() const = 0;
|
||||||
/*! \brief gets rank of current node */
|
/*! \brief gets rank of current node */
|
||||||
virtual int GetRank(void) const = 0;
|
virtual int GetRank() const = 0;
|
||||||
/*! \brief gets total number of nodes */
|
/*! \brief gets total number of nodes */
|
||||||
virtual int GetWorldSize(void) const = 0;
|
virtual int GetWorldSize() const = 0;
|
||||||
/*! \brief whether we run in distribted mode */
|
/*! \brief whether we run in distribted mode */
|
||||||
virtual bool IsDistributed(void) const = 0;
|
virtual bool IsDistributed() const = 0;
|
||||||
/*! \brief gets the host name of the current node */
|
/*! \brief gets the host name of the current node */
|
||||||
virtual std::string GetHost(void) const = 0;
|
virtual std::string GetHost() const = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief prints the msg in the tracker,
|
* \brief prints the msg in the tracker,
|
||||||
* this function can be used to communicate progress information to
|
* this function can be used to communicate progress information to
|
||||||
@ -211,9 +211,9 @@ class IEngine {
|
|||||||
/*! \brief initializes the engine module */
|
/*! \brief initializes the engine module */
|
||||||
bool Init(int argc, char *argv[]);
|
bool Init(int argc, char *argv[]);
|
||||||
/*! \brief finalizes the engine module */
|
/*! \brief finalizes the engine module */
|
||||||
bool Finalize(void);
|
bool Finalize();
|
||||||
/*! \brief singleton method to get engine */
|
/*! \brief singleton method to get engine */
|
||||||
IEngine *GetEngine(void);
|
IEngine *GetEngine();
|
||||||
|
|
||||||
/*! \brief namespace that contains stubs to be compatible with MPI */
|
/*! \brief namespace that contains stubs to be compatible with MPI */
|
||||||
namespace mpi {
|
namespace mpi {
|
||||||
@ -280,14 +280,14 @@ void Allgather(void* sendrecvbuf,
|
|||||||
* \param _line caller line number used to generate unique cache key
|
* \param _line caller line number used to generate unique cache key
|
||||||
* \param _caller caller function name used to generate unique cache key
|
* \param _caller caller function name used to generate unique cache key
|
||||||
*/
|
*/
|
||||||
void Allreduce_(void *sendrecvbuf,
|
void Allreduce_(void *sendrecvbuf, // NOLINT
|
||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
IEngine::ReduceFunction red,
|
IEngine::ReduceFunction red,
|
||||||
mpi::DataType dtype,
|
mpi::DataType dtype,
|
||||||
mpi::OpType op,
|
mpi::OpType op,
|
||||||
IEngine::PreprocFunction prepare_fun = NULL,
|
IEngine::PreprocFunction prepare_fun = nullptr,
|
||||||
void *prepare_arg = NULL,
|
void *prepare_arg = nullptr,
|
||||||
const char* _file = _FILE,
|
const char* _file = _FILE,
|
||||||
const int _line = _LINE,
|
const int _line = _LINE,
|
||||||
const char* _caller = _CALLER);
|
const char* _caller = _CALLER);
|
||||||
@ -298,9 +298,9 @@ void Allreduce_(void *sendrecvbuf,
|
|||||||
class ReduceHandle {
|
class ReduceHandle {
|
||||||
public:
|
public:
|
||||||
// constructor
|
// constructor
|
||||||
ReduceHandle(void);
|
ReduceHandle();
|
||||||
// destructor
|
// destructor
|
||||||
~ReduceHandle(void);
|
~ReduceHandle();
|
||||||
/*!
|
/*!
|
||||||
* \brief initialize the reduce function,
|
* \brief initialize the reduce function,
|
||||||
* with the type the reduce function needs to deal with
|
* with the type the reduce function needs to deal with
|
||||||
@ -323,8 +323,8 @@ class ReduceHandle {
|
|||||||
void Allreduce(void *sendrecvbuf,
|
void Allreduce(void *sendrecvbuf,
|
||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
IEngine::PreprocFunction prepare_fun = NULL,
|
IEngine::PreprocFunction prepare_fun = nullptr,
|
||||||
void *prepare_arg = NULL,
|
void *prepare_arg = nullptr,
|
||||||
const char* _file = _FILE,
|
const char* _file = _FILE,
|
||||||
const int _line = _LINE,
|
const int _line = _LINE,
|
||||||
const char* _caller = _CALLER);
|
const char* _caller = _CALLER);
|
||||||
@ -333,11 +333,11 @@ class ReduceHandle {
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
// handle function field
|
// handle function field
|
||||||
void *handle_;
|
void *handle_ {nullptr};
|
||||||
// reduce function of the reducer
|
// reduce function of the reducer
|
||||||
IEngine::ReduceFunction *redfunc_;
|
IEngine::ReduceFunction *redfunc_{nullptr};
|
||||||
// handle to the type field
|
// handle to the type field
|
||||||
void *htype_;
|
void *htype_{nullptr};
|
||||||
// the created type in 4 bytes
|
// the created type in 4 bytes
|
||||||
size_t created_type_nbytes_;
|
size_t created_type_nbytes_;
|
||||||
};
|
};
|
||||||
|
|||||||
@ -19,12 +19,12 @@
|
|||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace utils {
|
namespace utils {
|
||||||
/*! \brief re-use definition of dmlc::SeekStream */
|
/*! \brief re-use definition of dmlc::SeekStream */
|
||||||
typedef dmlc::SeekStream SeekStream;
|
using SeekStream = dmlc::SeekStream;
|
||||||
/*! \brief fixed size memory buffer */
|
/*! \brief fixed size memory buffer */
|
||||||
struct MemoryFixSizeBuffer : public SeekStream {
|
struct MemoryFixSizeBuffer : public SeekStream {
|
||||||
public:
|
public:
|
||||||
// similar to SEEK_END in libc
|
// similar to SEEK_END in libc
|
||||||
static size_t constexpr SeekEnd = std::numeric_limits<size_t>::max();
|
static size_t constexpr kSeekEnd = std::numeric_limits<size_t>::max();
|
||||||
|
|
||||||
public:
|
public:
|
||||||
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
|
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
|
||||||
@ -32,31 +32,31 @@ struct MemoryFixSizeBuffer : public SeekStream {
|
|||||||
buffer_size_(buffer_size) {
|
buffer_size_(buffer_size) {
|
||||||
curr_ptr_ = 0;
|
curr_ptr_ = 0;
|
||||||
}
|
}
|
||||||
virtual ~MemoryFixSizeBuffer(void) {}
|
~MemoryFixSizeBuffer() override = default;
|
||||||
virtual size_t Read(void *ptr, size_t size) {
|
size_t Read(void *ptr, size_t size) override {
|
||||||
size_t nread = std::min(buffer_size_ - curr_ptr_, size);
|
size_t nread = std::min(buffer_size_ - curr_ptr_, size);
|
||||||
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
|
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
|
||||||
curr_ptr_ += nread;
|
curr_ptr_ += nread;
|
||||||
return nread;
|
return nread;
|
||||||
}
|
}
|
||||||
virtual void Write(const void *ptr, size_t size) {
|
void Write(const void *ptr, size_t size) override {
|
||||||
if (size == 0) return;
|
if (size == 0) return;
|
||||||
utils::Assert(curr_ptr_ + size <= buffer_size_,
|
utils::Assert(curr_ptr_ + size <= buffer_size_,
|
||||||
"write position exceed fixed buffer size");
|
"write position exceed fixed buffer size");
|
||||||
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
|
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
|
||||||
curr_ptr_ += size;
|
curr_ptr_ += size;
|
||||||
}
|
}
|
||||||
virtual void Seek(size_t pos) {
|
void Seek(size_t pos) override {
|
||||||
if (pos == SeekEnd) {
|
if (pos == kSeekEnd) {
|
||||||
curr_ptr_ = buffer_size_;
|
curr_ptr_ = buffer_size_;
|
||||||
} else {
|
} else {
|
||||||
curr_ptr_ = static_cast<size_t>(pos);
|
curr_ptr_ = static_cast<size_t>(pos);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
virtual size_t Tell(void) {
|
size_t Tell() override {
|
||||||
return curr_ptr_;
|
return curr_ptr_;
|
||||||
}
|
}
|
||||||
virtual bool AtEnd(void) const {
|
virtual bool AtEnd() const {
|
||||||
return curr_ptr_ == buffer_size_;
|
return curr_ptr_ == buffer_size_;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -76,8 +76,8 @@ struct MemoryBufferStream : public SeekStream {
|
|||||||
: p_buffer_(p_buffer) {
|
: p_buffer_(p_buffer) {
|
||||||
curr_ptr_ = 0;
|
curr_ptr_ = 0;
|
||||||
}
|
}
|
||||||
virtual ~MemoryBufferStream(void) {}
|
~MemoryBufferStream() override = default;
|
||||||
virtual size_t Read(void *ptr, size_t size) {
|
size_t Read(void *ptr, size_t size) override {
|
||||||
utils::Assert(curr_ptr_ <= p_buffer_->length(),
|
utils::Assert(curr_ptr_ <= p_buffer_->length(),
|
||||||
"read can not have position excceed buffer length");
|
"read can not have position excceed buffer length");
|
||||||
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
|
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
|
||||||
@ -85,7 +85,7 @@ struct MemoryBufferStream : public SeekStream {
|
|||||||
curr_ptr_ += nread;
|
curr_ptr_ += nread;
|
||||||
return nread;
|
return nread;
|
||||||
}
|
}
|
||||||
virtual void Write(const void *ptr, size_t size) {
|
void Write(const void *ptr, size_t size) override {
|
||||||
if (size == 0) return;
|
if (size == 0) return;
|
||||||
if (curr_ptr_ + size > p_buffer_->length()) {
|
if (curr_ptr_ + size > p_buffer_->length()) {
|
||||||
p_buffer_->resize(curr_ptr_+size);
|
p_buffer_->resize(curr_ptr_+size);
|
||||||
@ -93,13 +93,13 @@ struct MemoryBufferStream : public SeekStream {
|
|||||||
std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
|
std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
|
||||||
curr_ptr_ += size;
|
curr_ptr_ += size;
|
||||||
}
|
}
|
||||||
virtual void Seek(size_t pos) {
|
void Seek(size_t pos) override {
|
||||||
curr_ptr_ = static_cast<size_t>(pos);
|
curr_ptr_ = static_cast<size_t>(pos);
|
||||||
}
|
}
|
||||||
virtual size_t Tell(void) {
|
size_t Tell() override {
|
||||||
return curr_ptr_;
|
return curr_ptr_;
|
||||||
}
|
}
|
||||||
virtual bool AtEnd(void) const {
|
virtual bool AtEnd() const {
|
||||||
return curr_ptr_ == p_buffer_->length();
|
return curr_ptr_ == p_buffer_->length();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -19,45 +19,45 @@ namespace engine {
|
|||||||
namespace mpi {
|
namespace mpi {
|
||||||
// template function to translate type to enum indicator
|
// template function to translate type to enum indicator
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
inline DataType GetType(void);
|
inline DataType GetType();
|
||||||
template<>
|
template<>
|
||||||
inline DataType GetType<char>(void) {
|
inline DataType GetType<char>() {
|
||||||
return kChar;
|
return kChar;
|
||||||
}
|
}
|
||||||
template<>
|
template<>
|
||||||
inline DataType GetType<unsigned char>(void) {
|
inline DataType GetType<unsigned char>() {
|
||||||
return kUChar;
|
return kUChar;
|
||||||
}
|
}
|
||||||
template<>
|
template<>
|
||||||
inline DataType GetType<int>(void) {
|
inline DataType GetType<int>() {
|
||||||
return kInt;
|
return kInt;
|
||||||
}
|
}
|
||||||
template<>
|
template<>
|
||||||
inline DataType GetType<unsigned int>(void) { // NOLINT(*)
|
inline DataType GetType<unsigned int>() { // NOLINT(*)
|
||||||
return kUInt;
|
return kUInt;
|
||||||
}
|
}
|
||||||
template<>
|
template<>
|
||||||
inline DataType GetType<long>(void) { // NOLINT(*)
|
inline DataType GetType<long>() { // NOLINT(*)
|
||||||
return kLong;
|
return kLong;
|
||||||
}
|
}
|
||||||
template<>
|
template<>
|
||||||
inline DataType GetType<unsigned long>(void) { // NOLINT(*)
|
inline DataType GetType<unsigned long>() { // NOLINT(*)
|
||||||
return kULong;
|
return kULong;
|
||||||
}
|
}
|
||||||
template<>
|
template<>
|
||||||
inline DataType GetType<float>(void) {
|
inline DataType GetType<float>() {
|
||||||
return kFloat;
|
return kFloat;
|
||||||
}
|
}
|
||||||
template<>
|
template<>
|
||||||
inline DataType GetType<double>(void) {
|
inline DataType GetType<double>() {
|
||||||
return kDouble;
|
return kDouble;
|
||||||
}
|
}
|
||||||
template<>
|
template<>
|
||||||
inline DataType GetType<long long>(void) { // NOLINT(*)
|
inline DataType GetType<long long>() { // NOLINT(*)
|
||||||
return kLongLong;
|
return kLongLong;
|
||||||
}
|
}
|
||||||
template<>
|
template<>
|
||||||
inline DataType GetType<unsigned long long>(void) { // NOLINT(*)
|
inline DataType GetType<unsigned long long>() { // NOLINT(*)
|
||||||
return kULongLong;
|
return kULongLong;
|
||||||
}
|
}
|
||||||
} // namespace mpi
|
} // namespace mpi
|
||||||
@ -94,7 +94,7 @@ struct BitOR {
|
|||||||
};
|
};
|
||||||
template<typename OP, typename DType>
|
template<typename OP, typename DType>
|
||||||
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
|
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
|
||||||
const DType* src = (const DType*)src_;
|
const DType* src = static_cast<const DType*>(src_);
|
||||||
DType* dst = (DType*)dst_; // NOLINT(*)
|
DType* dst = (DType*)dst_; // NOLINT(*)
|
||||||
for (int i = 0; i < len; i++) {
|
for (int i = 0; i < len; i++) {
|
||||||
OP::Reduce(dst[i], src[i]);
|
OP::Reduce(dst[i], src[i]);
|
||||||
@ -107,27 +107,27 @@ inline bool Init(int argc, char *argv[]) {
|
|||||||
return engine::Init(argc, argv);
|
return engine::Init(argc, argv);
|
||||||
}
|
}
|
||||||
// finalize the rabit engine
|
// finalize the rabit engine
|
||||||
inline bool Finalize(void) {
|
inline bool Finalize() {
|
||||||
return engine::Finalize();
|
return engine::Finalize();
|
||||||
}
|
}
|
||||||
// get the rank of the previous worker in ring topology
|
// get the rank of the previous worker in ring topology
|
||||||
inline int GetRingPrevRank(void) {
|
inline int GetRingPrevRank() {
|
||||||
return engine::GetEngine()->GetRingPrevRank();
|
return engine::GetEngine()->GetRingPrevRank();
|
||||||
}
|
}
|
||||||
// get the rank of current process
|
// get the rank of current process
|
||||||
inline int GetRank(void) {
|
inline int GetRank() {
|
||||||
return engine::GetEngine()->GetRank();
|
return engine::GetEngine()->GetRank();
|
||||||
}
|
}
|
||||||
// the the size of the world
|
// the the size of the world
|
||||||
inline int GetWorldSize(void) {
|
inline int GetWorldSize() {
|
||||||
return engine::GetEngine()->GetWorldSize();
|
return engine::GetEngine()->GetWorldSize();
|
||||||
}
|
}
|
||||||
// whether rabit is distributed
|
// whether rabit is distributed
|
||||||
inline bool IsDistributed(void) {
|
inline bool IsDistributed() {
|
||||||
return engine::GetEngine()->IsDistributed();
|
return engine::GetEngine()->IsDistributed();
|
||||||
}
|
}
|
||||||
// get the name of current processor
|
// get the name of current processor
|
||||||
inline std::string GetProcessorName(void) {
|
inline std::string GetProcessorName() {
|
||||||
return engine::GetEngine()->GetHost();
|
return engine::GetEngine()->GetHost();
|
||||||
}
|
}
|
||||||
// broadcast data to all other nodes from root
|
// broadcast data to all other nodes from root
|
||||||
@ -183,7 +183,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
|
|||||||
|
|
||||||
// C++11 support for lambda prepare function
|
// C++11 support for lambda prepare function
|
||||||
#if DMLC_USE_CXX11
|
#if DMLC_USE_CXX11
|
||||||
inline void InvokeLambda_(void *fun) {
|
inline void InvokeLambda(void *fun) {
|
||||||
(*static_cast<std::function<void()>*>(fun))();
|
(*static_cast<std::function<void()>*>(fun))();
|
||||||
}
|
}
|
||||||
template<typename OP, typename DType>
|
template<typename OP, typename DType>
|
||||||
@ -193,7 +193,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
|
|||||||
const int _line,
|
const int _line,
|
||||||
const char* _caller) {
|
const char* _caller) {
|
||||||
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
|
engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>,
|
||||||
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun,
|
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda, &prepare_fun,
|
||||||
_file, _line, _caller);
|
_file, _line, _caller);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -245,7 +245,7 @@ inline void LazyCheckPoint(const Serializable *global_model) {
|
|||||||
engine::GetEngine()->LazyCheckPoint(global_model);
|
engine::GetEngine()->LazyCheckPoint(global_model);
|
||||||
}
|
}
|
||||||
// return the version number of currently stored model
|
// return the version number of currently stored model
|
||||||
inline int VersionNumber(void) {
|
inline int VersionNumber() {
|
||||||
return engine::GetEngine()->VersionNumber();
|
return engine::GetEngine()->VersionNumber();
|
||||||
}
|
}
|
||||||
// ---------------------------------
|
// ---------------------------------
|
||||||
@ -253,7 +253,7 @@ inline int VersionNumber(void) {
|
|||||||
// ---------------------------------
|
// ---------------------------------
|
||||||
// function to perform reduction for Reducer
|
// function to perform reduction for Reducer
|
||||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
||||||
inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
inline void ReducerSafeImpl(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
||||||
const size_t kUnit = sizeof(DType);
|
const size_t kUnit = sizeof(DType);
|
||||||
const char *psrc = reinterpret_cast<const char*>(src_);
|
const char *psrc = reinterpret_cast<const char*>(src_);
|
||||||
char *pdst = reinterpret_cast<char*>(dst_);
|
char *pdst = reinterpret_cast<char*>(dst_);
|
||||||
@ -269,7 +269,7 @@ inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Data
|
|||||||
}
|
}
|
||||||
// function to perform reduction for Reducer
|
// function to perform reduction for Reducer
|
||||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
||||||
inline void ReducerAlign_(const void *src_, void *dst_,
|
inline void ReducerAlignImpl(const void *src_, void *dst_,
|
||||||
int len_, const MPI::Datatype &dtype) {
|
int len_, const MPI::Datatype &dtype) {
|
||||||
const DType *psrc = reinterpret_cast<const DType*>(src_);
|
const DType *psrc = reinterpret_cast<const DType*>(src_);
|
||||||
DType *pdst = reinterpret_cast<DType*>(dst_);
|
DType *pdst = reinterpret_cast<DType*>(dst_);
|
||||||
@ -278,12 +278,12 @@ inline void ReducerAlign_(const void *src_, void *dst_,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
||||||
inline Reducer<DType, freduce>::Reducer(void) {
|
inline Reducer<DType, freduce>::Reducer() {
|
||||||
// it is safe to directly use handle for aligned data types
|
// it is safe to directly use handle for aligned data types
|
||||||
if (sizeof(DType) == 8 || sizeof(DType) == 4 || sizeof(DType) == 1) {
|
if (sizeof(DType) == 8 || sizeof(DType) == 4 || sizeof(DType) == 1) {
|
||||||
this->handle_.Init(ReducerAlign_<DType, freduce>, sizeof(DType));
|
this->handle_.Init(ReducerAlignImpl<DType, freduce>, sizeof(DType));
|
||||||
} else {
|
} else {
|
||||||
this->handle_.Init(ReducerSafe_<DType, freduce>, sizeof(DType));
|
this->handle_.Init(ReducerSafeImpl<DType, freduce>, sizeof(DType));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)> // NOLINT(*)
|
||||||
@ -298,8 +298,8 @@ inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
|
|||||||
}
|
}
|
||||||
// function to perform reduction for SerializeReducer
|
// function to perform reduction for SerializeReducer
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
inline void SerializeReducerFunc_(const void *src_, void *dst_,
|
inline void SerializeReducerFuncImpl(const void *src_, void *dst_,
|
||||||
int len_, const MPI::Datatype &dtype) {
|
int len_, const MPI::Datatype &dtype) {
|
||||||
int nbytes = engine::ReduceHandle::TypeSize(dtype);
|
int nbytes = engine::ReduceHandle::TypeSize(dtype);
|
||||||
// temp space
|
// temp space
|
||||||
for (int i = 0; i < len_; ++i) {
|
for (int i = 0; i < len_; ++i) {
|
||||||
@ -315,8 +315,8 @@ inline void SerializeReducerFunc_(const void *src_, void *dst_,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
inline SerializeReducer<DType>::SerializeReducer(void) {
|
inline SerializeReducer<DType>::SerializeReducer() {
|
||||||
handle_.Init(SerializeReducerFunc_<DType>, sizeof(DType));
|
handle_.Init(SerializeReducerFuncImpl<DType>, sizeof(DType));
|
||||||
}
|
}
|
||||||
// closure to call Allreduce
|
// closure to call Allreduce
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
@ -327,8 +327,8 @@ struct SerializeReduceClosure {
|
|||||||
void *prepare_arg;
|
void *prepare_arg;
|
||||||
std::string *p_buffer;
|
std::string *p_buffer;
|
||||||
// invoke the closure
|
// invoke the closure
|
||||||
inline void Run(void) {
|
inline void Run() {
|
||||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
|
||||||
for (size_t i = 0; i < count; ++i) {
|
for (size_t i = 0; i < count; ++i) {
|
||||||
utils::MemoryFixSizeBuffer fs(BeginPtr(*p_buffer) + i * max_nbyte, max_nbyte);
|
utils::MemoryFixSizeBuffer fs(BeginPtr(*p_buffer) + i * max_nbyte, max_nbyte);
|
||||||
sendrecvobj[i].Save(fs);
|
sendrecvobj[i].Save(fs);
|
||||||
@ -368,7 +368,7 @@ inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
|
|||||||
const char* _file,
|
const char* _file,
|
||||||
const int _line,
|
const int _line,
|
||||||
const char* _caller) {
|
const char* _caller) {
|
||||||
this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun,
|
this->Allreduce(sendrecvbuf, count, InvokeLambda, &prepare_fun,
|
||||||
_file, _line, _caller);
|
_file, _line, _caller);
|
||||||
}
|
}
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
@ -378,7 +378,7 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
|||||||
const char* _file,
|
const char* _file,
|
||||||
const int _line,
|
const int _line,
|
||||||
const char* _caller) {
|
const char* _caller) {
|
||||||
this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda_, &prepare_fun,
|
this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda, &prepare_fun,
|
||||||
_file, _line, _caller);
|
_file, _line, _caller);
|
||||||
}
|
}
|
||||||
#endif // DMLC_USE_CXX11
|
#endif // DMLC_USE_CXX11
|
||||||
|
|||||||
@ -15,7 +15,7 @@
|
|||||||
#else
|
#else
|
||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
#include <netdb.h>
|
#include <netdb.h>
|
||||||
#include <errno.h>
|
#include <cerrno>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
#include <arpa/inet.h>
|
#include <arpa/inet.h>
|
||||||
#include <netinet/in.h>
|
#include <netinet/in.h>
|
||||||
@ -39,9 +39,9 @@ static inline int poll(struct pollfd *pfd, int nfds,
|
|||||||
int timeout) { return WSAPoll ( pfd, nfds, timeout ); }
|
int timeout) { return WSAPoll ( pfd, nfds, timeout ); }
|
||||||
#else
|
#else
|
||||||
#include <sys/poll.h>
|
#include <sys/poll.h>
|
||||||
typedef int SOCKET;
|
using SOCKET = int;
|
||||||
typedef size_t sock_size_t;
|
using sock_size_t = size_t; // NOLINT
|
||||||
const int INVALID_SOCKET = -1;
|
const int kInvalidSocket = -1;
|
||||||
#endif // defined(_WIN32)
|
#endif // defined(_WIN32)
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
@ -50,11 +50,11 @@ namespace utils {
|
|||||||
struct SockAddr {
|
struct SockAddr {
|
||||||
sockaddr_in addr;
|
sockaddr_in addr;
|
||||||
// constructor
|
// constructor
|
||||||
SockAddr(void) {}
|
SockAddr() = default;
|
||||||
SockAddr(const char *url, int port) {
|
SockAddr(const char *url, int port) {
|
||||||
this->Set(url, port);
|
this->Set(url, port);
|
||||||
}
|
}
|
||||||
inline static std::string GetHostName(void) {
|
inline static std::string GetHostName() {
|
||||||
std::string buf; buf.resize(256);
|
std::string buf; buf.resize(256);
|
||||||
utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
|
utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
|
||||||
return std::string(buf.c_str());
|
return std::string(buf.c_str());
|
||||||
@ -69,20 +69,20 @@ struct SockAddr {
|
|||||||
memset(&hints, 0, sizeof(hints));
|
memset(&hints, 0, sizeof(hints));
|
||||||
hints.ai_family = AF_INET;
|
hints.ai_family = AF_INET;
|
||||||
hints.ai_protocol = SOCK_STREAM;
|
hints.ai_protocol = SOCK_STREAM;
|
||||||
addrinfo *res = NULL;
|
addrinfo *res = nullptr;
|
||||||
int sig = getaddrinfo(host, NULL, &hints, &res);
|
int sig = getaddrinfo(host, nullptr, &hints, &res);
|
||||||
Check(sig == 0 && res != NULL, "cannot obtain address of %s", host);
|
Check(sig == 0 && res != nullptr, "cannot obtain address of %s", host);
|
||||||
Check(res->ai_family == AF_INET, "Does not support IPv6");
|
Check(res->ai_family == AF_INET, "Does not support IPv6");
|
||||||
memcpy(&addr, res->ai_addr, res->ai_addrlen);
|
memcpy(&addr, res->ai_addr, res->ai_addrlen);
|
||||||
addr.sin_port = htons(port);
|
addr.sin_port = htons(port);
|
||||||
freeaddrinfo(res);
|
freeaddrinfo(res);
|
||||||
}
|
}
|
||||||
/*! \brief return port of the address*/
|
/*! \brief return port of the address*/
|
||||||
inline int port(void) const {
|
inline int Port() const {
|
||||||
return ntohs(addr.sin_port);
|
return ntohs(addr.sin_port);
|
||||||
}
|
}
|
||||||
/*! \return a string representation of the address */
|
/*! \return a string representation of the address */
|
||||||
inline std::string AddrStr(void) const {
|
inline std::string AddrStr() const {
|
||||||
std::string buf; buf.resize(256);
|
std::string buf; buf.resize(256);
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
|
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
|
||||||
@ -91,7 +91,7 @@ struct SockAddr {
|
|||||||
const char *s = inet_ntop(AF_INET, &addr.sin_addr,
|
const char *s = inet_ntop(AF_INET, &addr.sin_addr,
|
||||||
&buf[0], buf.length());
|
&buf[0], buf.length());
|
||||||
#endif // _WIN32
|
#endif // _WIN32
|
||||||
Assert(s != NULL, "cannot decode address");
|
Assert(s != nullptr, "cannot decode address");
|
||||||
return std::string(s);
|
return std::string(s);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -104,13 +104,13 @@ class Socket {
|
|||||||
/*! \brief the file descriptor of socket */
|
/*! \brief the file descriptor of socket */
|
||||||
SOCKET sockfd;
|
SOCKET sockfd;
|
||||||
// default conversion to int
|
// default conversion to int
|
||||||
inline operator SOCKET() const {
|
operator SOCKET() const { // NOLINT
|
||||||
return sockfd;
|
return sockfd;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \return last error of socket operation
|
* \return last error of socket operation
|
||||||
*/
|
*/
|
||||||
inline static int GetLastError(void) {
|
inline static int GetLastError() {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
return WSAGetLastError();
|
return WSAGetLastError();
|
||||||
#else
|
#else
|
||||||
@ -118,7 +118,7 @@ class Socket {
|
|||||||
#endif // _WIN32
|
#endif // _WIN32
|
||||||
}
|
}
|
||||||
/*! \return whether last error was would block */
|
/*! \return whether last error was would block */
|
||||||
inline static bool LastErrorWouldBlock(void) {
|
inline static bool LastErrorWouldBlock() {
|
||||||
int errsv = GetLastError();
|
int errsv = GetLastError();
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
return errsv == WSAEWOULDBLOCK;
|
return errsv == WSAEWOULDBLOCK;
|
||||||
@ -130,7 +130,7 @@ class Socket {
|
|||||||
* \brief start up the socket module
|
* \brief start up the socket module
|
||||||
* call this before using the sockets
|
* call this before using the sockets
|
||||||
*/
|
*/
|
||||||
inline static void Startup(void) {
|
inline static void Startup() {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
WSADATA wsa_data;
|
WSADATA wsa_data;
|
||||||
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
|
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
|
||||||
@ -145,7 +145,7 @@ class Socket {
|
|||||||
/*!
|
/*!
|
||||||
* \brief shutdown the socket module after use, all sockets need to be closed
|
* \brief shutdown the socket module after use, all sockets need to be closed
|
||||||
*/
|
*/
|
||||||
inline static void Finalize(void) {
|
inline static void Finalize() {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
WSACleanup();
|
WSACleanup();
|
||||||
#endif // _WIN32
|
#endif // _WIN32
|
||||||
@ -214,7 +214,7 @@ class Socket {
|
|||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
/*! \brief get last error code if any */
|
/*! \brief get last error code if any */
|
||||||
inline int GetSockError(void) const {
|
inline int GetSockError() const {
|
||||||
int error = 0;
|
int error = 0;
|
||||||
socklen_t len = sizeof(error);
|
socklen_t len = sizeof(error);
|
||||||
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
|
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
|
||||||
@ -224,25 +224,25 @@ class Socket {
|
|||||||
return error;
|
return error;
|
||||||
}
|
}
|
||||||
/*! \brief check if anything bad happens */
|
/*! \brief check if anything bad happens */
|
||||||
inline bool BadSocket(void) const {
|
inline bool BadSocket() const {
|
||||||
if (IsClosed()) return true;
|
if (IsClosed()) return true;
|
||||||
int err = GetSockError();
|
int err = GetSockError();
|
||||||
if (err == EBADF || err == EINTR) return true;
|
if (err == EBADF || err == EINTR) return true;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
/*! \brief check if socket is already closed */
|
/*! \brief check if socket is already closed */
|
||||||
inline bool IsClosed(void) const {
|
inline bool IsClosed() const {
|
||||||
return sockfd == INVALID_SOCKET;
|
return sockfd == kInvalidSocket;
|
||||||
}
|
}
|
||||||
/*! \brief close the socket */
|
/*! \brief close the socket */
|
||||||
inline void Close(void) {
|
inline void Close() {
|
||||||
if (sockfd != INVALID_SOCKET) {
|
if (sockfd != kInvalidSocket) {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
closesocket(sockfd);
|
closesocket(sockfd);
|
||||||
#else
|
#else
|
||||||
close(sockfd);
|
close(sockfd);
|
||||||
#endif
|
#endif
|
||||||
sockfd = INVALID_SOCKET;
|
sockfd = kInvalidSocket;
|
||||||
} else {
|
} else {
|
||||||
Error("Socket::Close double close the socket or close without create");
|
Error("Socket::Close double close the socket or close without create");
|
||||||
}
|
}
|
||||||
@ -268,7 +268,7 @@ class Socket {
|
|||||||
class TCPSocket : public Socket{
|
class TCPSocket : public Socket{
|
||||||
public:
|
public:
|
||||||
// constructor
|
// constructor
|
||||||
TCPSocket(void) : Socket(INVALID_SOCKET) {
|
TCPSocket() : Socket(kInvalidSocket) {
|
||||||
}
|
}
|
||||||
explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) {
|
explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) {
|
||||||
}
|
}
|
||||||
@ -297,7 +297,7 @@ class TCPSocket : public Socket{
|
|||||||
*/
|
*/
|
||||||
inline void Create(int af = PF_INET) {
|
inline void Create(int af = PF_INET) {
|
||||||
sockfd = socket(PF_INET, SOCK_STREAM, 0);
|
sockfd = socket(PF_INET, SOCK_STREAM, 0);
|
||||||
if (sockfd == INVALID_SOCKET) {
|
if (sockfd == kInvalidSocket) {
|
||||||
Socket::Error("Create");
|
Socket::Error("Create");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -309,9 +309,9 @@ class TCPSocket : public Socket{
|
|||||||
listen(sockfd, backlog);
|
listen(sockfd, backlog);
|
||||||
}
|
}
|
||||||
/*! \brief get a new connection */
|
/*! \brief get a new connection */
|
||||||
TCPSocket Accept(void) {
|
TCPSocket Accept() {
|
||||||
SOCKET newfd = accept(sockfd, NULL, NULL);
|
SOCKET newfd = accept(sockfd, nullptr, nullptr);
|
||||||
if (newfd == INVALID_SOCKET) {
|
if (newfd == kInvalidSocket) {
|
||||||
Socket::Error("Accept");
|
Socket::Error("Accept");
|
||||||
}
|
}
|
||||||
return TCPSocket(newfd);
|
return TCPSocket(newfd);
|
||||||
@ -320,7 +320,7 @@ class TCPSocket : public Socket{
|
|||||||
* \brief decide whether the socket is at OOB mark
|
* \brief decide whether the socket is at OOB mark
|
||||||
* \return 1 if at mark, 0 if not, -1 if an error occured
|
* \return 1 if at mark, 0 if not, -1 if an error occured
|
||||||
*/
|
*/
|
||||||
inline int AtMark(void) const {
|
inline int AtMark() const {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
unsigned long atmark; // NOLINT(*)
|
unsigned long atmark; // NOLINT(*)
|
||||||
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
|
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
|
||||||
|
|||||||
@ -1,87 +0,0 @@
|
|||||||
/*!
|
|
||||||
* Copyright (c) 2015 by Contributors
|
|
||||||
* \file thread_local.h
|
|
||||||
* \brief Common utility for thread local storage.
|
|
||||||
*/
|
|
||||||
#ifndef RABIT_INTERNAL_THREAD_LOCAL_H_
|
|
||||||
#define RABIT_INTERNAL_THREAD_LOCAL_H_
|
|
||||||
|
|
||||||
#include "../include/dmlc/base.h"
|
|
||||||
|
|
||||||
#if DMLC_ENABLE_STD_THREAD
|
|
||||||
#include <mutex>
|
|
||||||
#endif // DMLC_ENABLE_STD_THREAD
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
namespace rabit {
|
|
||||||
|
|
||||||
// macro hanlding for threadlocal variables
|
|
||||||
#ifdef __GNUC__
|
|
||||||
#define MX_TREAD_LOCAL __thread
|
|
||||||
#elif __STDC_VERSION__ >= 201112L
|
|
||||||
#define MX_TREAD_LOCAL _Thread_local
|
|
||||||
#elif defined(_MSC_VER)
|
|
||||||
#define MX_TREAD_LOCAL __declspec(thread)
|
|
||||||
#endif // __GNUC__
|
|
||||||
|
|
||||||
#ifndef MX_TREAD_LOCAL
|
|
||||||
#message("Warning: Threadlocal is not enabled");
|
|
||||||
#endif // MX_TREAD_LOCAL
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief A threadlocal store to store threadlocal variables.
|
|
||||||
* Will return a thread local singleton of type T
|
|
||||||
* \tparam T the type we like to store
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
class ThreadLocalStore {
|
|
||||||
public:
|
|
||||||
/*! \return get a thread local singleton */
|
|
||||||
static T* Get() {
|
|
||||||
static MX_TREAD_LOCAL T* ptr = nullptr;
|
|
||||||
if (ptr == nullptr) {
|
|
||||||
ptr = new T();
|
|
||||||
Singleton()->RegisterDelete(ptr);
|
|
||||||
}
|
|
||||||
return ptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
/*! \brief constructor */
|
|
||||||
ThreadLocalStore() {}
|
|
||||||
/*! \brief destructor */
|
|
||||||
~ThreadLocalStore() {
|
|
||||||
for (size_t i = 0; i < data_.size(); ++i) {
|
|
||||||
delete data_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/*! \return singleton of the store */
|
|
||||||
static ThreadLocalStore<T> *Singleton() {
|
|
||||||
static ThreadLocalStore<T> inst;
|
|
||||||
return &inst;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief register str for internal deletion
|
|
||||||
* \param str the string pointer
|
|
||||||
*/
|
|
||||||
void RegisterDelete(T *str) {
|
|
||||||
#if DMLC_ENABLE_STD_THREAD
|
|
||||||
std::unique_lock<std::mutex> lock(mutex_);
|
|
||||||
data_.push_back(str);
|
|
||||||
lock.unlock();
|
|
||||||
#else
|
|
||||||
data_.push_back(str);
|
|
||||||
#endif // DMLC_ENABLE_STD_THREAD
|
|
||||||
}
|
|
||||||
|
|
||||||
#if DMLC_ENABLE_STD_THREAD
|
|
||||||
/*! \brief internal mutex */
|
|
||||||
std::mutex mutex_;
|
|
||||||
#endif // DMLC_ENABLE_STD_THREAD
|
|
||||||
/*!\brief internal data */
|
|
||||||
std::vector<T*> data_;
|
|
||||||
};
|
|
||||||
} // namespace rabit
|
|
||||||
#endif // RABIT_INTERNAL_THREAD_LOCAL_H_
|
|
||||||
@ -6,7 +6,7 @@
|
|||||||
*/
|
*/
|
||||||
#ifndef RABIT_INTERNAL_TIMER_H_
|
#ifndef RABIT_INTERNAL_TIMER_H_
|
||||||
#define RABIT_INTERNAL_TIMER_H_
|
#define RABIT_INTERNAL_TIMER_H_
|
||||||
#include <time.h>
|
#include <ctime>
|
||||||
#ifdef __MACH__
|
#ifdef __MACH__
|
||||||
#include <mach/clock.h>
|
#include <mach/clock.h>
|
||||||
#include <mach/mach.h>
|
#include <mach/mach.h>
|
||||||
@ -18,7 +18,7 @@ namespace utils {
|
|||||||
/*!
|
/*!
|
||||||
* \brief return time in seconds, not cross platform, avoid to use this in most places
|
* \brief return time in seconds, not cross platform, avoid to use this in most places
|
||||||
*/
|
*/
|
||||||
inline double GetTime(void) {
|
inline double GetTime() {
|
||||||
#ifdef __MACH__
|
#ifdef __MACH__
|
||||||
clock_serv_t cclock;
|
clock_serv_t cclock;
|
||||||
mach_timespec_t mts;
|
mach_timespec_t mts;
|
||||||
|
|||||||
@ -8,7 +8,7 @@
|
|||||||
#define RABIT_INTERNAL_UTILS_H_
|
#define RABIT_INTERNAL_UTILS_H_
|
||||||
|
|
||||||
#include <rabit/base.h>
|
#include <rabit/base.h>
|
||||||
#include <string.h>
|
#include <cstring>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
@ -48,15 +48,7 @@ extern "C" {
|
|||||||
}
|
}
|
||||||
#endif // _MSC_VER
|
#endif // _MSC_VER
|
||||||
|
|
||||||
#ifdef _MSC_VER
|
#include <cinttypes>
|
||||||
typedef unsigned char uint8_t;
|
|
||||||
typedef unsigned __int16 uint16_t;
|
|
||||||
typedef unsigned __int32 uint32_t;
|
|
||||||
typedef unsigned __int64 uint64_t;
|
|
||||||
typedef __int64 int64_t;
|
|
||||||
#else
|
|
||||||
#include <inttypes.h>
|
|
||||||
#endif // _MSC_VER
|
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
/*! \brief namespace for helper utils of the project */
|
/*! \brief namespace for helper utils of the project */
|
||||||
@ -184,7 +176,7 @@ inline void Error(const char *fmt, ...) {
|
|||||||
/*! \brief replace fopen, report error when the file open fails */
|
/*! \brief replace fopen, report error when the file open fails */
|
||||||
inline std::FILE *FopenCheck(const char *fname, const char *flag) {
|
inline std::FILE *FopenCheck(const char *fname, const char *flag) {
|
||||||
std::FILE *fp = fopen64(fname, flag);
|
std::FILE *fp = fopen64(fname, flag);
|
||||||
Check(fp != NULL, "can not open file \"%s\"\n", fname);
|
Check(fp != nullptr, "can not open file \"%s\"\n", fname);
|
||||||
return fp;
|
return fp;
|
||||||
}
|
}
|
||||||
} // namespace utils
|
} // namespace utils
|
||||||
@ -193,7 +185,7 @@ inline std::FILE *FopenCheck(const char *fname, const char *flag) {
|
|||||||
template<typename T>
|
template<typename T>
|
||||||
inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*)
|
inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*)
|
||||||
if (vec.size() == 0) {
|
if (vec.size() == 0) {
|
||||||
return NULL;
|
return nullptr;
|
||||||
} else {
|
} else {
|
||||||
return &vec[0];
|
return &vec[0];
|
||||||
}
|
}
|
||||||
@ -202,17 +194,17 @@ inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*)
|
|||||||
template<typename T>
|
template<typename T>
|
||||||
inline const T *BeginPtr(const std::vector<T> &vec) { // NOLINT(*)
|
inline const T *BeginPtr(const std::vector<T> &vec) { // NOLINT(*)
|
||||||
if (vec.size() == 0) {
|
if (vec.size() == 0) {
|
||||||
return NULL;
|
return nullptr;
|
||||||
} else {
|
} else {
|
||||||
return &vec[0];
|
return &vec[0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
inline char* BeginPtr(std::string &str) { // NOLINT(*)
|
inline char* BeginPtr(std::string &str) { // NOLINT(*)
|
||||||
if (str.length() == 0) return NULL;
|
if (str.length() == 0) return nullptr;
|
||||||
return &str[0];
|
return &str[0];
|
||||||
}
|
}
|
||||||
inline const char* BeginPtr(const std::string &str) {
|
inline const char* BeginPtr(const std::string &str) {
|
||||||
if (str.length() == 0) return NULL;
|
if (str.length() == 0) return nullptr;
|
||||||
return &str[0];
|
return &str[0];
|
||||||
}
|
}
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|||||||
@ -53,12 +53,12 @@ namespace rabit {
|
|||||||
* \brief defines stream used in rabit
|
* \brief defines stream used in rabit
|
||||||
* see definition of Stream in dmlc/io.h
|
* see definition of Stream in dmlc/io.h
|
||||||
*/
|
*/
|
||||||
typedef dmlc::Stream Stream;
|
using Stream = dmlc::Stream;
|
||||||
/*!
|
/*!
|
||||||
* \brief defines serializable objects used in rabit
|
* \brief defines serializable objects used in rabit
|
||||||
* see definition of Serializable in dmlc/io.h
|
* see definition of Serializable in dmlc/io.h
|
||||||
*/
|
*/
|
||||||
typedef dmlc::Serializable Serializable;
|
using Serializable = dmlc::Serializable;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief reduction operators namespace
|
* \brief reduction operators namespace
|
||||||
@ -199,8 +199,8 @@ inline void Broadcast(std::string *sendrecv_data, int root,
|
|||||||
*/
|
*/
|
||||||
template<typename OP, typename DType>
|
template<typename OP, typename DType>
|
||||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||||
void (*prepare_fun)(void *) = NULL,
|
void (*prepare_fun)(void *) = nullptr,
|
||||||
void *prepare_arg = NULL,
|
void *prepare_arg = nullptr,
|
||||||
const char* _file = _FILE,
|
const char* _file = _FILE,
|
||||||
const int _line = _LINE,
|
const int _line = _LINE,
|
||||||
const char* _caller = _CALLER);
|
const char* _caller = _CALLER);
|
||||||
@ -291,7 +291,7 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
|
|||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
inline int LoadCheckPoint(Serializable *global_model,
|
inline int LoadCheckPoint(Serializable *global_model,
|
||||||
Serializable *local_model = NULL);
|
Serializable *local_model = nullptr);
|
||||||
/*!
|
/*!
|
||||||
* \brief checkpoints the model, meaning a stage of execution has finished.
|
* \brief checkpoints the model, meaning a stage of execution has finished.
|
||||||
* every time we call check point, a version number will be increased by one
|
* every time we call check point, a version number will be increased by one
|
||||||
@ -307,7 +307,7 @@ inline int LoadCheckPoint(Serializable *global_model,
|
|||||||
* \sa LoadCheckPoint, VersionNumber
|
* \sa LoadCheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
inline void CheckPoint(const Serializable *global_model,
|
inline void CheckPoint(const Serializable *global_model,
|
||||||
const Serializable *local_model = NULL);
|
const Serializable *local_model = nullptr);
|
||||||
/*!
|
/*!
|
||||||
* \brief This function can be used to replace CheckPoint for global_model only,
|
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||||
* when certain condition is met (see detailed explanation).
|
* when certain condition is met (see detailed explanation).
|
||||||
@ -366,8 +366,8 @@ class Reducer {
|
|||||||
* \param _caller caller function name used to generate unique cache key
|
* \param _caller caller function name used to generate unique cache key
|
||||||
*/
|
*/
|
||||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||||
void (*prepare_fun)(void *) = NULL,
|
void (*prepare_fun)(void *) = nullptr,
|
||||||
void *prepare_arg = NULL,
|
void *prepare_arg = nullptr,
|
||||||
const char* _file = _FILE,
|
const char* _file = _FILE,
|
||||||
const int _line = _LINE,
|
const int _line = _LINE,
|
||||||
const char* _caller = _CALLER);
|
const char* _caller = _CALLER);
|
||||||
@ -422,8 +422,8 @@ class SerializeReducer {
|
|||||||
*/
|
*/
|
||||||
inline void Allreduce(DType *sendrecvobj,
|
inline void Allreduce(DType *sendrecvobj,
|
||||||
size_t max_nbyte, size_t count,
|
size_t max_nbyte, size_t count,
|
||||||
void (*prepare_fun)(void *) = NULL,
|
void (*prepare_fun)(void *) = nullptr,
|
||||||
void *prepare_arg = NULL,
|
void *prepare_arg = nullptr,
|
||||||
const char* _file = _FILE,
|
const char* _file = _FILE,
|
||||||
const int _line = _LINE,
|
const int _line = _LINE,
|
||||||
const char* _caller = _CALLER);
|
const char* _caller = _CALLER);
|
||||||
|
|||||||
@ -15,12 +15,12 @@ namespace rabit {
|
|||||||
* \brief defines stream used in rabit
|
* \brief defines stream used in rabit
|
||||||
* see definition of Stream in dmlc/io.h
|
* see definition of Stream in dmlc/io.h
|
||||||
*/
|
*/
|
||||||
typedef dmlc::Stream Stream;
|
using Stream = dmlc::Stream ;
|
||||||
/*!
|
/*!
|
||||||
* \brief defines serializable objects used in rabit
|
* \brief defines serializable objects used in rabit
|
||||||
* see definition of Serializable in dmlc/io.h
|
* see definition of Serializable in dmlc/io.h
|
||||||
*/
|
*/
|
||||||
typedef dmlc::Serializable Serializable;
|
using Serializable = dmlc::Serializable;
|
||||||
|
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
#endif // RABIT_SERIALIZABLE_H_
|
#endif // RABIT_SERIALIZABLE_H_
|
||||||
|
|||||||
@ -15,7 +15,7 @@
|
|||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
// constructor
|
// constructor
|
||||||
AllreduceBase::AllreduceBase(void) {
|
AllreduceBase::AllreduceBase() {
|
||||||
tracker_uri = "NULL";
|
tracker_uri = "NULL";
|
||||||
tracker_port = 9000;
|
tracker_port = 9000;
|
||||||
host_uri = "";
|
host_uri = "";
|
||||||
@ -24,7 +24,7 @@ AllreduceBase::AllreduceBase(void) {
|
|||||||
rank = 0;
|
rank = 0;
|
||||||
world_size = -1;
|
world_size = -1;
|
||||||
connect_retry = 5;
|
connect_retry = 5;
|
||||||
hadoop_mode = 0;
|
hadoop_mode = false;
|
||||||
version_number = 0;
|
version_number = 0;
|
||||||
// 32 K items
|
// 32 K items
|
||||||
reduce_ring_mincount = 32 << 10;
|
reduce_ring_mincount = 32 << 10;
|
||||||
@ -32,27 +32,27 @@ AllreduceBase::AllreduceBase(void) {
|
|||||||
tree_reduce_minsize = 1 << 20;
|
tree_reduce_minsize = 1 << 20;
|
||||||
// tracker URL
|
// tracker URL
|
||||||
task_id = "NULL";
|
task_id = "NULL";
|
||||||
err_link = NULL;
|
err_link = nullptr;
|
||||||
dmlc_role = "worker";
|
dmlc_role = "worker";
|
||||||
this->SetParam("rabit_reduce_buffer", "256MB");
|
this->SetParam("rabit_reduce_buffer", "256MB");
|
||||||
// setup possible enviroment variable of interest
|
// setup possible enviroment variable of interest
|
||||||
// include dmlc support direct variables
|
// include dmlc support direct variables
|
||||||
env_vars.push_back("DMLC_TASK_ID");
|
env_vars.emplace_back("DMLC_TASK_ID");
|
||||||
env_vars.push_back("DMLC_ROLE");
|
env_vars.emplace_back("DMLC_ROLE");
|
||||||
env_vars.push_back("DMLC_NUM_ATTEMPT");
|
env_vars.emplace_back("DMLC_NUM_ATTEMPT");
|
||||||
env_vars.push_back("DMLC_TRACKER_URI");
|
env_vars.emplace_back("DMLC_TRACKER_URI");
|
||||||
env_vars.push_back("DMLC_TRACKER_PORT");
|
env_vars.emplace_back("DMLC_TRACKER_PORT");
|
||||||
env_vars.push_back("DMLC_WORKER_CONNECT_RETRY");
|
env_vars.emplace_back("DMLC_WORKER_CONNECT_RETRY");
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialization function
|
// initialization function
|
||||||
bool AllreduceBase::Init(int argc, char* argv[]) {
|
bool AllreduceBase::Init(int argc, char* argv[]) {
|
||||||
// setup from enviroment variables
|
// setup from enviroment variables
|
||||||
// handler to get variables from env
|
// handler to get variables from env
|
||||||
for (size_t i = 0; i < env_vars.size(); ++i) {
|
for (auto & env_var : env_vars) {
|
||||||
const char *value = getenv(env_vars[i].c_str());
|
const char *value = getenv(env_var.c_str());
|
||||||
if (value != NULL) {
|
if (value != nullptr) {
|
||||||
this->SetParam(env_vars[i].c_str(), value);
|
this->SetParam(env_var.c_str(), value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// pass in arguments override env variable.
|
// pass in arguments override env variable.
|
||||||
@ -66,35 +66,35 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
|
|||||||
{
|
{
|
||||||
// handling for hadoop
|
// handling for hadoop
|
||||||
const char *task_id = getenv("mapred_tip_id");
|
const char *task_id = getenv("mapred_tip_id");
|
||||||
if (task_id == NULL) {
|
if (task_id == nullptr) {
|
||||||
task_id = getenv("mapreduce_task_id");
|
task_id = getenv("mapreduce_task_id");
|
||||||
}
|
}
|
||||||
if (hadoop_mode) {
|
if (hadoop_mode) {
|
||||||
utils::Check(task_id != NULL,
|
utils::Check(task_id != nullptr,
|
||||||
"hadoop_mode is set but cannot find mapred_task_id");
|
"hadoop_mode is set but cannot find mapred_task_id");
|
||||||
}
|
}
|
||||||
if (task_id != NULL) {
|
if (task_id != nullptr) {
|
||||||
this->SetParam("rabit_task_id", task_id);
|
this->SetParam("rabit_task_id", task_id);
|
||||||
this->SetParam("rabit_hadoop_mode", "1");
|
this->SetParam("rabit_hadoop_mode", "1");
|
||||||
}
|
}
|
||||||
const char *attempt_id = getenv("mapred_task_id");
|
const char *attempt_id = getenv("mapred_task_id");
|
||||||
if (attempt_id != 0) {
|
if (attempt_id != nullptr) {
|
||||||
const char *att = strrchr(attempt_id, '_');
|
const char *att = strrchr(attempt_id, '_');
|
||||||
int num_trial;
|
int num_trial;
|
||||||
if (att != NULL && sscanf(att + 1, "%d", &num_trial) == 1) {
|
if (att != nullptr && sscanf(att + 1, "%d", &num_trial) == 1) {
|
||||||
this->SetParam("rabit_num_trial", att + 1);
|
this->SetParam("rabit_num_trial", att + 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// handling for hadoop
|
// handling for hadoop
|
||||||
const char *num_task = getenv("mapred_map_tasks");
|
const char *num_task = getenv("mapred_map_tasks");
|
||||||
if (num_task == NULL) {
|
if (num_task == nullptr) {
|
||||||
num_task = getenv("mapreduce_job_maps");
|
num_task = getenv("mapreduce_job_maps");
|
||||||
}
|
}
|
||||||
if (hadoop_mode) {
|
if (hadoop_mode) {
|
||||||
utils::Check(num_task != NULL,
|
utils::Check(num_task != nullptr,
|
||||||
"hadoop_mode is set but cannot find mapred_map_tasks");
|
"hadoop_mode is set but cannot find mapred_map_tasks");
|
||||||
}
|
}
|
||||||
if (num_task != NULL) {
|
if (num_task != nullptr) {
|
||||||
this->SetParam("rabit_world_size", num_task);
|
this->SetParam("rabit_world_size", num_task);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -115,10 +115,10 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
|
|||||||
return this->ReConnectLinks();
|
return this->ReConnectLinks();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AllreduceBase::Shutdown(void) {
|
bool AllreduceBase::Shutdown() {
|
||||||
try {
|
try {
|
||||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
for (auto & all_link : all_links) {
|
||||||
all_links[i].sock.Close();
|
all_link.sock.Close();
|
||||||
}
|
}
|
||||||
all_links.clear();
|
all_links.clear();
|
||||||
tree_links.plinks.clear();
|
tree_links.plinks.clear();
|
||||||
@ -208,17 +208,18 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non negative second");
|
utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non negative second");
|
||||||
}
|
}
|
||||||
if (!strcmp(name, "rabit_enable_tcp_no_delay")) {
|
if (!strcmp(name, "rabit_enable_tcp_no_delay")) {
|
||||||
if (!strcmp(val, "true"))
|
if (!strcmp(val, "true")) {
|
||||||
rabit_enable_tcp_no_delay = true;
|
rabit_enable_tcp_no_delay = true;
|
||||||
else
|
} else {
|
||||||
rabit_enable_tcp_no_delay = false;
|
rabit_enable_tcp_no_delay = false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief initialize connection to the tracker
|
* \brief initialize connection to the tracker
|
||||||
* \return a socket that initializes the connection
|
* \return a socket that initializes the connection
|
||||||
*/
|
*/
|
||||||
utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
|
utils::TCPSocket AllreduceBase::ConnectTracker() const {
|
||||||
int magic = kMagic;
|
int magic = kMagic;
|
||||||
// get information from tracker
|
// get information from tracker
|
||||||
utils::TCPSocket tracker;
|
utils::TCPSocket tracker;
|
||||||
@ -241,7 +242,7 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
} while (1);
|
} while (true);
|
||||||
|
|
||||||
using utils::Assert;
|
using utils::Assert;
|
||||||
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
|
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
|
||||||
@ -320,19 +321,19 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
do {
|
do {
|
||||||
// send over good links
|
// send over good links
|
||||||
std::vector<int> good_link;
|
std::vector<int> good_link;
|
||||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
for (auto & all_link : all_links) {
|
||||||
if (!all_links[i].sock.BadSocket()) {
|
if (!all_link.sock.BadSocket()) {
|
||||||
good_link.push_back(static_cast<int>(all_links[i].rank));
|
good_link.push_back(static_cast<int>(all_link.rank));
|
||||||
} else {
|
} else {
|
||||||
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
|
if (!all_link.sock.IsClosed()) all_link.sock.Close();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
int ngood = static_cast<int>(good_link.size());
|
int ngood = static_cast<int>(good_link.size());
|
||||||
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
|
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
|
||||||
"ReConnectLink failure 5");
|
"ReConnectLink failure 5");
|
||||||
for (size_t i = 0; i < good_link.size(); ++i) {
|
for (int & i : good_link) {
|
||||||
Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \
|
Assert(tracker.SendAll(&i, sizeof(i)) == \
|
||||||
sizeof(good_link[i]), "ReConnectLink failure 6");
|
sizeof(i), "ReConnectLink failure 6");
|
||||||
}
|
}
|
||||||
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
||||||
"ReConnectLink failure 7");
|
"ReConnectLink failure 7");
|
||||||
@ -362,11 +363,11 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
utils::Check(hrank == r.rank,
|
utils::Check(hrank == r.rank,
|
||||||
"ReConnectLink failure, link rank inconsistent");
|
"ReConnectLink failure, link rank inconsistent");
|
||||||
bool match = false;
|
bool match = false;
|
||||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
for (auto & all_link : all_links) {
|
||||||
if (all_links[i].rank == hrank) {
|
if (all_link.rank == hrank) {
|
||||||
Assert(all_links[i].sock.IsClosed(),
|
Assert(all_link.sock.IsClosed(),
|
||||||
"Override a link that is active");
|
"Override a link that is active");
|
||||||
all_links[i].sock = r.sock;
|
all_link.sock = r.sock;
|
||||||
match = true;
|
match = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -390,11 +391,11 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
|
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
|
||||||
"ReConnectLink failure 15");
|
"ReConnectLink failure 15");
|
||||||
bool match = false;
|
bool match = false;
|
||||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
for (auto & all_link : all_links) {
|
||||||
if (all_links[i].rank == r.rank) {
|
if (all_link.rank == r.rank) {
|
||||||
utils::Assert(all_links[i].sock.IsClosed(),
|
utils::Assert(all_link.sock.IsClosed(),
|
||||||
"Override a link that is active");
|
"Override a link that is active");
|
||||||
all_links[i].sock = r.sock;
|
all_link.sock = r.sock;
|
||||||
match = true;
|
match = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -406,29 +407,29 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
// setup tree links and ring structure
|
// setup tree links and ring structure
|
||||||
tree_links.plinks.clear();
|
tree_links.plinks.clear();
|
||||||
int tcpNoDelay = 1;
|
int tcpNoDelay = 1;
|
||||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
for (auto & all_link : all_links) {
|
||||||
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
|
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
|
||||||
// set the socket to non-blocking mode, enable TCP keepalive
|
// set the socket to non-blocking mode, enable TCP keepalive
|
||||||
all_links[i].sock.SetNonBlock(true);
|
all_link.sock.SetNonBlock(true);
|
||||||
all_links[i].sock.SetKeepAlive(true);
|
all_link.sock.SetKeepAlive(true);
|
||||||
if (rabit_enable_tcp_no_delay) {
|
if (rabit_enable_tcp_no_delay) {
|
||||||
setsockopt(all_links[i].sock, IPPROTO_TCP,
|
setsockopt(all_link.sock, IPPROTO_TCP,
|
||||||
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
|
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
|
||||||
}
|
}
|
||||||
if (tree_neighbors.count(all_links[i].rank) != 0) {
|
if (tree_neighbors.count(all_link.rank) != 0) {
|
||||||
if (all_links[i].rank == parent_rank) {
|
if (all_link.rank == parent_rank) {
|
||||||
parent_index = static_cast<int>(tree_links.plinks.size());
|
parent_index = static_cast<int>(tree_links.plinks.size());
|
||||||
}
|
}
|
||||||
tree_links.plinks.push_back(&all_links[i]);
|
tree_links.plinks.push_back(&all_link);
|
||||||
}
|
}
|
||||||
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
|
if (all_link.rank == prev_rank) ring_prev = &all_link;
|
||||||
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
|
if (all_link.rank == next_rank) ring_next = &all_link;
|
||||||
}
|
}
|
||||||
Assert(parent_rank == -1 || parent_index != -1,
|
Assert(parent_rank == -1 || parent_index != -1,
|
||||||
"cannot find parent in the link");
|
"cannot find parent in the link");
|
||||||
Assert(prev_rank == -1 || ring_prev != NULL,
|
Assert(prev_rank == -1 || ring_prev != nullptr,
|
||||||
"cannot find prev ring in the link");
|
"cannot find prev ring in the link");
|
||||||
Assert(next_rank == -1 || ring_next != NULL,
|
Assert(next_rank == -1 || ring_next != nullptr,
|
||||||
"cannot find next ring in the link");
|
"cannot find next ring in the link");
|
||||||
return true;
|
return true;
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
@ -479,11 +480,11 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
|||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer) {
|
ReduceFunction reducer) {
|
||||||
RefLinkVector &links = tree_links;
|
RefLinkVector &links = tree_links;
|
||||||
if (links.size() == 0 || count == 0) return kSuccess;
|
if (links.Size() == 0 || count == 0) return kSuccess;
|
||||||
// total size of message
|
// total size of message
|
||||||
const size_t total_size = type_nbytes * count;
|
const size_t total_size = type_nbytes * count;
|
||||||
// number of links
|
// number of links
|
||||||
const int nlink = static_cast<int>(links.size());
|
const int nlink = static_cast<int>(links.Size());
|
||||||
// send recv buffer
|
// send recv buffer
|
||||||
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
|
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
|
||||||
// size of space that we already performs reduce in up pass
|
// size of space that we already performs reduce in up pass
|
||||||
@ -581,8 +582,9 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
|||||||
|
|
||||||
// if max reduce is less than total size, we reduce multiple times of
|
// if max reduce is less than total size, we reduce multiple times of
|
||||||
// eachreduce size
|
// eachreduce size
|
||||||
if (max_reduce < total_size)
|
if (max_reduce < total_size) {
|
||||||
max_reduce = max_reduce - max_reduce % eachreduce;
|
max_reduce = max_reduce - max_reduce % eachreduce;
|
||||||
|
}
|
||||||
|
|
||||||
// peform reduce, can be at most two rounds
|
// peform reduce, can be at most two rounds
|
||||||
while (size_up_reduce < max_reduce) {
|
while (size_up_reduce < max_reduce) {
|
||||||
@ -677,11 +679,11 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
|||||||
AllreduceBase::ReturnType
|
AllreduceBase::ReturnType
|
||||||
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||||
RefLinkVector &links = tree_links;
|
RefLinkVector &links = tree_links;
|
||||||
if (links.size() == 0 || total_size == 0) return kSuccess;
|
if (links.Size() == 0 || total_size == 0) return kSuccess;
|
||||||
utils::Check(root < world_size,
|
utils::Check(root < world_size,
|
||||||
"Broadcast: root should be smaller than world size");
|
"Broadcast: root should be smaller than world size");
|
||||||
// number of links
|
// number of links
|
||||||
const int nlink = static_cast<int>(links.size());
|
const int nlink = static_cast<int>(links.Size());
|
||||||
// size of space already read from data
|
// size of space already read from data
|
||||||
size_t size_in = 0;
|
size_t size_in = 0;
|
||||||
// input link, -2 means unknown yet, -1 means this is root
|
// input link, -2 means unknown yet, -1 means this is root
|
||||||
|
|||||||
@ -25,7 +25,7 @@
|
|||||||
#endif // RABIT_CXXTESTDEFS_H
|
#endif // RABIT_CXXTESTDEFS_H
|
||||||
|
|
||||||
|
|
||||||
namespace MPI {
|
namespace MPI { // NOLINT
|
||||||
// MPI data type to be compatible with existing MPI interface
|
// MPI data type to be compatible with existing MPI interface
|
||||||
class Datatype {
|
class Datatype {
|
||||||
public:
|
public:
|
||||||
@ -41,12 +41,12 @@ class AllreduceBase : public IEngine {
|
|||||||
// magic number to verify server
|
// magic number to verify server
|
||||||
static const int kMagic = 0xff99;
|
static const int kMagic = 0xff99;
|
||||||
// constant one byte out of band message to indicate error happening
|
// constant one byte out of band message to indicate error happening
|
||||||
AllreduceBase(void);
|
AllreduceBase();
|
||||||
virtual ~AllreduceBase(void) {}
|
virtual ~AllreduceBase() = default;
|
||||||
// initialize the manager
|
// initialize the manager
|
||||||
virtual bool Init(int argc, char* argv[]);
|
virtual bool Init(int argc, char* argv[]);
|
||||||
// shutdown the engine
|
// shutdown the engine
|
||||||
virtual bool Shutdown(void);
|
virtual bool Shutdown();
|
||||||
/*!
|
/*!
|
||||||
* \brief set parameters to the engine
|
* \brief set parameters to the engine
|
||||||
* \param name parameter name
|
* \param name parameter name
|
||||||
@ -59,27 +59,27 @@ class AllreduceBase : public IEngine {
|
|||||||
* the user who monitors the tracker
|
* the user who monitors the tracker
|
||||||
* \param msg message to be printed in the tracker
|
* \param msg message to be printed in the tracker
|
||||||
*/
|
*/
|
||||||
virtual void TrackerPrint(const std::string &msg);
|
void TrackerPrint(const std::string &msg) override;
|
||||||
|
|
||||||
/*! \brief get rank of previous node in ring topology*/
|
/*! \brief get rank of previous node in ring topology*/
|
||||||
virtual int GetRingPrevRank(void) const {
|
int GetRingPrevRank() const override {
|
||||||
return ring_prev->rank;
|
return ring_prev->rank;
|
||||||
}
|
}
|
||||||
/*! \brief get rank */
|
/*! \brief get rank */
|
||||||
virtual int GetRank(void) const {
|
int GetRank() const override {
|
||||||
return rank;
|
return rank;
|
||||||
}
|
}
|
||||||
/*! \brief get rank */
|
/*! \brief get rank */
|
||||||
virtual int GetWorldSize(void) const {
|
int GetWorldSize() const override {
|
||||||
if (world_size == -1) return 1;
|
if (world_size == -1) return 1;
|
||||||
return world_size;
|
return world_size;
|
||||||
}
|
}
|
||||||
/*! \brief whether is distributed or not */
|
/*! \brief whether is distributed or not */
|
||||||
virtual bool IsDistributed(void) const {
|
bool IsDistributed() const override {
|
||||||
return tracker_uri != "NULL";
|
return tracker_uri != "NULL";
|
||||||
}
|
}
|
||||||
/*! \brief get rank */
|
/*! \brief get rank */
|
||||||
virtual std::string GetHost(void) const {
|
std::string GetHost() const override {
|
||||||
return host_uri;
|
return host_uri;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,13 +99,10 @@ class AllreduceBase : public IEngine {
|
|||||||
* \param _line caller line number used to generate unique cache key
|
* \param _line caller line number used to generate unique cache key
|
||||||
* \param _caller caller function name used to generate unique cache key
|
* \param _caller caller function name used to generate unique cache key
|
||||||
*/
|
*/
|
||||||
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
|
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
|
||||||
size_t slice_begin,
|
size_t slice_end, size_t size_prev_slice,
|
||||||
size_t slice_end,
|
const char *_file = _FILE, const int _line = _LINE,
|
||||||
size_t size_prev_slice,
|
const char *_caller = _CALLER) override {
|
||||||
const char* _file = _FILE,
|
|
||||||
const int _line = _LINE,
|
|
||||||
const char* _caller = _CALLER) {
|
|
||||||
if (world_size == 1 || world_size == -1) return;
|
if (world_size == 1 || world_size == -1) return;
|
||||||
utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size,
|
utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size,
|
||||||
slice_begin, slice_end, size_prev_slice) == kSuccess,
|
slice_begin, slice_end, size_prev_slice) == kSuccess,
|
||||||
@ -126,16 +123,12 @@ class AllreduceBase : public IEngine {
|
|||||||
* \param _line caller line number used to generate unique cache key
|
* \param _line caller line number used to generate unique cache key
|
||||||
* \param _caller caller function name used to generate unique cache key
|
* \param _caller caller function name used to generate unique cache key
|
||||||
*/
|
*/
|
||||||
virtual void Allreduce(void *sendrecvbuf_,
|
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
|
||||||
size_t type_nbytes,
|
ReduceFunction reducer, PreprocFunction prepare_fun = nullptr,
|
||||||
size_t count,
|
void *prepare_arg = nullptr, const char *_file = _FILE,
|
||||||
ReduceFunction reducer,
|
const int _line = _LINE,
|
||||||
PreprocFunction prepare_fun = NULL,
|
const char *_caller = _CALLER) override {
|
||||||
void *prepare_arg = NULL,
|
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
|
||||||
const char* _file = _FILE,
|
|
||||||
const int _line = _LINE,
|
|
||||||
const char* _caller = _CALLER) {
|
|
||||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
|
||||||
if (world_size == 1 || world_size == -1) return;
|
if (world_size == 1 || world_size == -1) return;
|
||||||
utils::Assert(TryAllreduce(sendrecvbuf_,
|
utils::Assert(TryAllreduce(sendrecvbuf_,
|
||||||
type_nbytes, count, reducer) == kSuccess,
|
type_nbytes, count, reducer) == kSuccess,
|
||||||
@ -150,8 +143,9 @@ class AllreduceBase : public IEngine {
|
|||||||
* \param _line caller line number used to generate unique cache key
|
* \param _line caller line number used to generate unique cache key
|
||||||
* \param _caller caller function name used to generate unique cache key
|
* \param _caller caller function name used to generate unique cache key
|
||||||
*/
|
*/
|
||||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||||
const char* _file = _FILE, const int _line = _LINE, const char* _caller = _CALLER) {
|
const char *_file = _FILE, const int _line = _LINE,
|
||||||
|
const char *_caller = _CALLER) override {
|
||||||
if (world_size == 1 || world_size == -1) return;
|
if (world_size == 1 || world_size == -1) return;
|
||||||
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
|
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
|
||||||
"Broadcast failed");
|
"Broadcast failed");
|
||||||
@ -178,8 +172,8 @@ class AllreduceBase : public IEngine {
|
|||||||
*
|
*
|
||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual int LoadCheckPoint(Serializable *global_model,
|
int LoadCheckPoint(Serializable *global_model,
|
||||||
Serializable *local_model = NULL) {
|
Serializable *local_model = nullptr) override {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -198,8 +192,8 @@ class AllreduceBase : public IEngine {
|
|||||||
*
|
*
|
||||||
* \sa LoadCheckPoint, VersionNumber
|
* \sa LoadCheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual void CheckPoint(const Serializable *global_model,
|
void CheckPoint(const Serializable *global_model,
|
||||||
const Serializable *local_model = NULL) {
|
const Serializable *local_model = nullptr) override {
|
||||||
version_number += 1;
|
version_number += 1;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -222,7 +216,7 @@ class AllreduceBase : public IEngine {
|
|||||||
* is the same in all nodes
|
* is the same in all nodes
|
||||||
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual void LazyCheckPoint(const Serializable *global_model) {
|
void LazyCheckPoint(const Serializable *global_model) override {
|
||||||
version_number += 1;
|
version_number += 1;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -230,7 +224,7 @@ class AllreduceBase : public IEngine {
|
|||||||
* which means how many calls to CheckPoint we made so far
|
* which means how many calls to CheckPoint we made so far
|
||||||
* \sa LoadCheckPoint, CheckPoint
|
* \sa LoadCheckPoint, CheckPoint
|
||||||
*/
|
*/
|
||||||
virtual int VersionNumber(void) const {
|
int VersionNumber() const override {
|
||||||
return version_number;
|
return version_number;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -238,14 +232,14 @@ class AllreduceBase : public IEngine {
|
|||||||
* call this function when IEngine throw an exception out,
|
* call this function when IEngine throw an exception out,
|
||||||
* this function is only used for test purpose
|
* this function is only used for test purpose
|
||||||
*/
|
*/
|
||||||
virtual void InitAfterException(void) {
|
void InitAfterException() override {
|
||||||
utils::Error("InitAfterException: not implemented");
|
utils::Error("InitAfterException: not implemented");
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief report current status to the job tracker
|
* \brief report current status to the job tracker
|
||||||
* depending on the job tracker we are in
|
* depending on the job tracker we are in
|
||||||
*/
|
*/
|
||||||
inline void ReportStatus(void) const {
|
inline void ReportStatus() const {
|
||||||
if (hadoop_mode != 0) {
|
if (hadoop_mode != 0) {
|
||||||
fprintf(stderr, "reporter:status:Rabit Phase[%03d] Operation %03d\n",
|
fprintf(stderr, "reporter:status:Rabit Phase[%03d] Operation %03d\n",
|
||||||
version_number, seq_counter);
|
version_number, seq_counter);
|
||||||
@ -274,7 +268,7 @@ class AllreduceBase : public IEngine {
|
|||||||
/*! \brief internal return type */
|
/*! \brief internal return type */
|
||||||
ReturnTypeEnum value;
|
ReturnTypeEnum value;
|
||||||
// constructor
|
// constructor
|
||||||
ReturnType() {}
|
ReturnType() = default;
|
||||||
ReturnType(ReturnTypeEnum value) : value(value) {} // NOLINT(*)
|
ReturnType(ReturnTypeEnum value) : value(value) {} // NOLINT(*)
|
||||||
inline bool operator==(const ReturnTypeEnum &v) const {
|
inline bool operator==(const ReturnTypeEnum &v) const {
|
||||||
return value == v;
|
return value == v;
|
||||||
@ -306,13 +300,11 @@ class AllreduceBase : public IEngine {
|
|||||||
// size of data sent to the link
|
// size of data sent to the link
|
||||||
size_t size_write;
|
size_t size_write;
|
||||||
// pointer to buffer head
|
// pointer to buffer head
|
||||||
char *buffer_head;
|
char *buffer_head {nullptr};
|
||||||
// buffer size, in bytes
|
// buffer size, in bytes
|
||||||
size_t buffer_size;
|
size_t buffer_size {0};
|
||||||
// constructor
|
// constructor
|
||||||
LinkRecord(void)
|
LinkRecord() = default;
|
||||||
: buffer_head(NULL), buffer_size(0) {
|
|
||||||
}
|
|
||||||
// initialize buffer
|
// initialize buffer
|
||||||
inline void InitBuffer(size_t type_nbytes, size_t count,
|
inline void InitBuffer(size_t type_nbytes, size_t count,
|
||||||
size_t reduce_buffer_size) {
|
size_t reduce_buffer_size) {
|
||||||
@ -328,7 +320,7 @@ class AllreduceBase : public IEngine {
|
|||||||
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
|
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
|
||||||
}
|
}
|
||||||
// reset the recv and sent size
|
// reset the recv and sent size
|
||||||
inline void ResetSize(void) {
|
inline void ResetSize() {
|
||||||
size_write = size_read = 0;
|
size_write = size_read = 0;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -340,7 +332,7 @@ class AllreduceBase : public IEngine {
|
|||||||
* \return the type of reading
|
* \return the type of reading
|
||||||
*/
|
*/
|
||||||
inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) {
|
inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) {
|
||||||
utils::Assert(buffer_head != NULL, "ReadToRingBuffer: buffer not allocated");
|
utils::Assert(buffer_head != nullptr, "ReadToRingBuffer: buffer not allocated");
|
||||||
utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check");
|
utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check");
|
||||||
size_t ngap = size_read - protect_start;
|
size_t ngap = size_read - protect_start;
|
||||||
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
|
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
|
||||||
@ -405,7 +397,7 @@ class AllreduceBase : public IEngine {
|
|||||||
inline LinkRecord &operator[](size_t i) {
|
inline LinkRecord &operator[](size_t i) {
|
||||||
return *plinks[i];
|
return *plinks[i];
|
||||||
}
|
}
|
||||||
inline size_t size(void) const {
|
inline size_t Size() const {
|
||||||
return plinks.size();
|
return plinks.size();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -413,7 +405,7 @@ class AllreduceBase : public IEngine {
|
|||||||
* \brief initialize connection to the tracker
|
* \brief initialize connection to the tracker
|
||||||
* \return a socket that initializes the connection
|
* \return a socket that initializes the connection
|
||||||
*/
|
*/
|
||||||
utils::TCPSocket ConnectTracker(void) const;
|
utils::TCPSocket ConnectTracker() const;
|
||||||
/*!
|
/*!
|
||||||
* \brief connect to the tracker to fix the the missing links
|
* \brief connect to the tracker to fix the the missing links
|
||||||
* this function is also used when the engine start up
|
* this function is also used when the engine start up
|
||||||
@ -525,64 +517,64 @@ class AllreduceBase : public IEngine {
|
|||||||
//---- data structure related to model ----
|
//---- data structure related to model ----
|
||||||
// call sequence counter, records how many calls we made so far
|
// call sequence counter, records how many calls we made so far
|
||||||
// from last call to CheckPoint, LoadCheckPoint
|
// from last call to CheckPoint, LoadCheckPoint
|
||||||
int seq_counter;
|
int seq_counter; // NOLINT
|
||||||
// version number of model
|
// version number of model
|
||||||
int version_number;
|
int version_number; // NOLINT
|
||||||
// whether the job is running in hadoop
|
// whether the job is running in hadoop
|
||||||
bool hadoop_mode;
|
bool hadoop_mode; // NOLINT
|
||||||
//---- local data related to link ----
|
//---- local data related to link ----
|
||||||
// index of parent link, can be -1, meaning this is root of the tree
|
// index of parent link, can be -1, meaning this is root of the tree
|
||||||
int parent_index;
|
int parent_index; // NOLINT
|
||||||
// rank of parent node, can be -1
|
// rank of parent node, can be -1
|
||||||
int parent_rank;
|
int parent_rank; // NOLINT
|
||||||
// sockets of all links this connects to
|
// sockets of all links this connects to
|
||||||
std::vector<LinkRecord> all_links;
|
std::vector<LinkRecord> all_links; // NOLINT
|
||||||
// used to record the link where things goes wrong
|
// used to record the link where things goes wrong
|
||||||
LinkRecord *err_link;
|
LinkRecord *err_link; // NOLINT
|
||||||
// all the links in the reduction tree connection
|
// all the links in the reduction tree connection
|
||||||
RefLinkVector tree_links;
|
RefLinkVector tree_links; // NOLINT
|
||||||
// pointer to links in the ring
|
// pointer to links in the ring
|
||||||
LinkRecord *ring_prev, *ring_next;
|
LinkRecord *ring_prev, *ring_next; // NOLINT
|
||||||
//----- meta information-----
|
//----- meta information-----
|
||||||
// list of enviroment variables that are of possible interest
|
// list of enviroment variables that are of possible interest
|
||||||
std::vector<std::string> env_vars;
|
std::vector<std::string> env_vars; // NOLINT
|
||||||
// unique identifier of the possible job this process is doing
|
// unique identifier of the possible job this process is doing
|
||||||
// used to assign ranks, optional, default to NULL
|
// used to assign ranks, optional, default to NULL
|
||||||
std::string task_id;
|
std::string task_id; // NOLINT
|
||||||
// uri of current host, to be set by Init
|
// uri of current host, to be set by Init
|
||||||
std::string host_uri;
|
std::string host_uri; // NOLINT
|
||||||
// uri of tracker
|
// uri of tracker
|
||||||
std::string tracker_uri;
|
std::string tracker_uri; // NOLINT
|
||||||
// role in dmlc jobs
|
// role in dmlc jobs
|
||||||
std::string dmlc_role;
|
std::string dmlc_role; // NOLINT
|
||||||
// port of tracker address
|
// port of tracker address
|
||||||
int tracker_port;
|
int tracker_port; // NOLINT
|
||||||
// port of slave process
|
// port of slave process
|
||||||
int slave_port, nport_trial;
|
int slave_port, nport_trial; // NOLINT
|
||||||
// reduce buffer size
|
// reduce buffer size
|
||||||
size_t reduce_buffer_size;
|
size_t reduce_buffer_size; // NOLINT
|
||||||
// reduction method
|
// reduction method
|
||||||
int reduce_method;
|
int reduce_method; // NOLINT
|
||||||
// mininum count of cells to use ring based method
|
// mininum count of cells to use ring based method
|
||||||
size_t reduce_ring_mincount;
|
size_t reduce_ring_mincount; // NOLINT
|
||||||
// minimul block size per tree reduce
|
// minimul block size per tree reduce
|
||||||
size_t tree_reduce_minsize;
|
size_t tree_reduce_minsize; // NOLINT
|
||||||
// current rank
|
// current rank
|
||||||
int rank;
|
int rank; // NOLINT
|
||||||
// world size
|
// world size
|
||||||
int world_size;
|
int world_size; // NOLINT
|
||||||
// connect retry time
|
// connect retry time
|
||||||
int connect_retry;
|
int connect_retry; // NOLINT
|
||||||
// enable bootstrap cache 0 false 1 true
|
// enable bootstrap cache 0 false 1 true
|
||||||
bool rabit_bootstrap_cache = false;
|
bool rabit_bootstrap_cache = false; // NOLINT
|
||||||
// enable detailed logging
|
// enable detailed logging
|
||||||
bool rabit_debug = false;
|
bool rabit_debug = false; // NOLINT
|
||||||
// by default, if rabit worker not recover in half an hour exit
|
// by default, if rabit worker not recover in half an hour exit
|
||||||
int timeout_sec = 1800;
|
int timeout_sec = 1800; // NOLINT
|
||||||
// flag to enable rabit_timeout
|
// flag to enable rabit_timeout
|
||||||
bool rabit_timeout = false;
|
bool rabit_timeout = false; // NOLINT
|
||||||
// Enable TCP node delay
|
// Enable TCP node delay
|
||||||
bool rabit_enable_tcp_no_delay = false;
|
bool rabit_enable_tcp_no_delay = false; // NOLINT
|
||||||
};
|
};
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|||||||
@ -20,74 +20,65 @@ namespace engine {
|
|||||||
class AllreduceMock : public AllreduceRobust {
|
class AllreduceMock : public AllreduceRobust {
|
||||||
public:
|
public:
|
||||||
// constructor
|
// constructor
|
||||||
AllreduceMock(void) {
|
AllreduceMock() {
|
||||||
num_trial = 0;
|
num_trial_ = 0;
|
||||||
force_local = 0;
|
force_local_ = 0;
|
||||||
report_stats = 0;
|
report_stats_ = 0;
|
||||||
tsum_allreduce = 0.0;
|
tsum_allreduce_ = 0.0;
|
||||||
tsum_allgather = 0.0;
|
tsum_allgather_ = 0.0;
|
||||||
}
|
}
|
||||||
// destructor
|
// destructor
|
||||||
virtual ~AllreduceMock(void) {}
|
~AllreduceMock() override = default;
|
||||||
virtual void SetParam(const char *name, const char *val) {
|
void SetParam(const char *name, const char *val) override {
|
||||||
AllreduceRobust::SetParam(name, val);
|
AllreduceRobust::SetParam(name, val);
|
||||||
// additional parameters
|
// additional parameters
|
||||||
if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val);
|
if (!strcmp(name, "rabit_num_trial")) num_trial_ = atoi(val);
|
||||||
if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial = atoi(val);
|
if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial_ = atoi(val);
|
||||||
if (!strcmp(name, "report_stats")) report_stats = atoi(val);
|
if (!strcmp(name, "report_stats")) report_stats_ = atoi(val);
|
||||||
if (!strcmp(name, "force_local")) force_local = atoi(val);
|
if (!strcmp(name, "force_local")) force_local_ = atoi(val);
|
||||||
if (!strcmp(name, "mock")) {
|
if (!strcmp(name, "mock")) {
|
||||||
MockKey k;
|
MockKey k;
|
||||||
utils::Check(sscanf(val, "%d,%d,%d,%d",
|
utils::Check(sscanf(val, "%d,%d,%d,%d",
|
||||||
&k.rank, &k.version, &k.seqno, &k.ntrial) == 4,
|
&k.rank, &k.version, &k.seqno, &k.ntrial) == 4,
|
||||||
"invalid mock parameter");
|
"invalid mock parameter");
|
||||||
mock_map[k] = 1;
|
mock_map_[k] = 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
virtual void Allreduce(void *sendrecvbuf_,
|
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
|
||||||
size_t type_nbytes,
|
ReduceFunction reducer, PreprocFunction prepare_fun,
|
||||||
size_t count,
|
void *prepare_arg, const char *_file = _FILE,
|
||||||
ReduceFunction reducer,
|
const int _line = _LINE,
|
||||||
PreprocFunction prepare_fun,
|
const char *_caller = _CALLER) override {
|
||||||
void *prepare_arg,
|
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "AllReduce");
|
||||||
const char* _file = _FILE,
|
|
||||||
const int _line = _LINE,
|
|
||||||
const char* _caller = _CALLER) {
|
|
||||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
|
|
||||||
double tstart = utils::GetTime();
|
double tstart = utils::GetTime();
|
||||||
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
|
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
|
||||||
count, reducer, prepare_fun, prepare_arg,
|
count, reducer, prepare_fun, prepare_arg,
|
||||||
_file, _line, _caller);
|
_file, _line, _caller);
|
||||||
tsum_allreduce += utils::GetTime() - tstart;
|
tsum_allreduce_ += utils::GetTime() - tstart;
|
||||||
}
|
}
|
||||||
virtual void Allgather(void *sendrecvbuf,
|
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin,
|
||||||
size_t total_size,
|
size_t slice_end, size_t size_prev_slice,
|
||||||
size_t slice_begin,
|
const char *_file = _FILE, const int _line = _LINE,
|
||||||
size_t slice_end,
|
const char *_caller = _CALLER) override {
|
||||||
size_t size_prev_slice,
|
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Allgather");
|
||||||
const char* _file = _FILE,
|
|
||||||
const int _line = _LINE,
|
|
||||||
const char* _caller = _CALLER) {
|
|
||||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather");
|
|
||||||
double tstart = utils::GetTime();
|
double tstart = utils::GetTime();
|
||||||
AllreduceRobust::Allgather(sendrecvbuf, total_size,
|
AllreduceRobust::Allgather(sendrecvbuf, total_size,
|
||||||
slice_begin, slice_end,
|
slice_begin, slice_end,
|
||||||
size_prev_slice, _file, _line, _caller);
|
size_prev_slice, _file, _line, _caller);
|
||||||
tsum_allgather += utils::GetTime() - tstart;
|
tsum_allgather_ += utils::GetTime() - tstart;
|
||||||
}
|
}
|
||||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||||
const char* _file = _FILE,
|
const char *_file = _FILE, const int _line = _LINE,
|
||||||
const int _line = _LINE,
|
const char *_caller = _CALLER) override {
|
||||||
const char* _caller = _CALLER) {
|
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Broadcast");
|
||||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
|
|
||||||
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller);
|
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller);
|
||||||
}
|
}
|
||||||
virtual int LoadCheckPoint(Serializable *global_model,
|
int LoadCheckPoint(Serializable *global_model,
|
||||||
Serializable *local_model) {
|
Serializable *local_model) override {
|
||||||
tsum_allreduce = 0.0;
|
tsum_allreduce_ = 0.0;
|
||||||
tsum_allgather = 0.0;
|
tsum_allgather_ = 0.0;
|
||||||
time_checkpoint = utils::GetTime();
|
time_checkpoint_ = utils::GetTime();
|
||||||
if (force_local == 0) {
|
if (force_local_ == 0) {
|
||||||
return AllreduceRobust::LoadCheckPoint(global_model, local_model);
|
return AllreduceRobust::LoadCheckPoint(global_model, local_model);
|
||||||
} else {
|
} else {
|
||||||
DummySerializer dum;
|
DummySerializer dum;
|
||||||
@ -95,56 +86,54 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
return AllreduceRobust::LoadCheckPoint(&dum, &com);
|
return AllreduceRobust::LoadCheckPoint(&dum, &com);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
virtual void CheckPoint(const Serializable *global_model,
|
void CheckPoint(const Serializable *global_model,
|
||||||
const Serializable *local_model) {
|
const Serializable *local_model) override {
|
||||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
|
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "CheckPoint");
|
||||||
double tstart = utils::GetTime();
|
double tstart = utils::GetTime();
|
||||||
double tbet_chkpt = tstart - time_checkpoint;
|
double tbet_chkpt = tstart - time_checkpoint_;
|
||||||
if (force_local == 0) {
|
if (force_local_ == 0) {
|
||||||
AllreduceRobust::CheckPoint(global_model, local_model);
|
AllreduceRobust::CheckPoint(global_model, local_model);
|
||||||
} else {
|
} else {
|
||||||
DummySerializer dum;
|
DummySerializer dum;
|
||||||
ComboSerializer com(global_model, local_model);
|
ComboSerializer com(global_model, local_model);
|
||||||
AllreduceRobust::CheckPoint(&dum, &com);
|
AllreduceRobust::CheckPoint(&dum, &com);
|
||||||
}
|
}
|
||||||
time_checkpoint = utils::GetTime();
|
time_checkpoint_ = utils::GetTime();
|
||||||
double tcost = utils::GetTime() - tstart;
|
double tcost = utils::GetTime() - tstart;
|
||||||
if (report_stats != 0 && rank == 0) {
|
if (report_stats_ != 0 && rank == 0) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "[v" << version_number << "] global_size=" << global_checkpoint.length()
|
ss << "[v" << version_number << "] global_size=" << global_checkpoint_.length()
|
||||||
<< ",local_size=" << (local_chkpt[0].length() + local_chkpt[1].length())
|
<< ",local_size=" << (local_chkpt_[0].length() + local_chkpt_[1].length())
|
||||||
<< ",check_tcost="<< tcost <<" sec"
|
<< ",check_tcost="<< tcost <<" sec"
|
||||||
<< ",allreduce_tcost=" << tsum_allreduce << " sec"
|
<< ",allreduce_tcost=" << tsum_allreduce_ << " sec"
|
||||||
<< ",allgather_tcost=" << tsum_allgather << " sec"
|
<< ",allgather_tcost=" << tsum_allgather_ << " sec"
|
||||||
<< ",between_chpt=" << tbet_chkpt << "sec\n";
|
<< ",between_chpt=" << tbet_chkpt << "sec\n";
|
||||||
this->TrackerPrint(ss.str());
|
this->TrackerPrint(ss.str());
|
||||||
}
|
}
|
||||||
tsum_allreduce = 0.0;
|
tsum_allreduce_ = 0.0;
|
||||||
tsum_allgather = 0.0;
|
tsum_allgather_ = 0.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual void LazyCheckPoint(const Serializable *global_model) {
|
void LazyCheckPoint(const Serializable *global_model) override {
|
||||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
|
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "LazyCheckPoint");
|
||||||
AllreduceRobust::LazyCheckPoint(global_model);
|
AllreduceRobust::LazyCheckPoint(global_model);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// force checkpoint to local
|
// force checkpoint to local
|
||||||
int force_local;
|
int force_local_;
|
||||||
// whether report statistics
|
// whether report statistics
|
||||||
int report_stats;
|
int report_stats_;
|
||||||
// sum of allreduce
|
// sum of allreduce
|
||||||
double tsum_allreduce;
|
double tsum_allreduce_;
|
||||||
// sum of allgather
|
// sum of allgather
|
||||||
double tsum_allgather;
|
double tsum_allgather_;
|
||||||
double time_checkpoint;
|
double time_checkpoint_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
struct DummySerializer : public Serializable {
|
struct DummySerializer : public Serializable {
|
||||||
virtual void Load(Stream *fi) {
|
void Load(Stream *fi) override {}
|
||||||
}
|
void Save(Stream *fo) const override {}
|
||||||
virtual void Save(Stream *fo) const {
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
struct ComboSerializer : public Serializable {
|
struct ComboSerializer : public Serializable {
|
||||||
Serializable *lhs;
|
Serializable *lhs;
|
||||||
@ -155,15 +144,15 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
: lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
|
: lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
|
||||||
}
|
}
|
||||||
ComboSerializer(const Serializable *lhs, const Serializable *rhs)
|
ComboSerializer(const Serializable *lhs, const Serializable *rhs)
|
||||||
: lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
|
: lhs(nullptr), rhs(nullptr), c_lhs(lhs), c_rhs(rhs) {
|
||||||
}
|
}
|
||||||
virtual void Load(Stream *fi) {
|
void Load(Stream *fi) override {
|
||||||
if (lhs != NULL) lhs->Load(fi);
|
if (lhs != nullptr) lhs->Load(fi);
|
||||||
if (rhs != NULL) rhs->Load(fi);
|
if (rhs != nullptr) rhs->Load(fi);
|
||||||
}
|
}
|
||||||
virtual void Save(Stream *fo) const {
|
void Save(Stream *fo) const override {
|
||||||
if (c_lhs != NULL) c_lhs->Save(fo);
|
if (c_lhs != nullptr) c_lhs->Save(fo);
|
||||||
if (c_rhs != NULL) c_rhs->Save(fo);
|
if (c_rhs != nullptr) c_rhs->Save(fo);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
// key to identify the mock stage
|
// key to identify the mock stage
|
||||||
@ -172,7 +161,7 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
int version;
|
int version;
|
||||||
int seqno;
|
int seqno;
|
||||||
int ntrial;
|
int ntrial;
|
||||||
MockKey(void) {}
|
MockKey() = default;
|
||||||
MockKey(int rank, int version, int seqno, int ntrial)
|
MockKey(int rank, int version, int seqno, int ntrial)
|
||||||
: rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}
|
: rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}
|
||||||
inline bool operator==(const MockKey &b) const {
|
inline bool operator==(const MockKey &b) const {
|
||||||
@ -189,15 +178,15 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
// number of failure trials
|
// number of failure trials
|
||||||
int num_trial;
|
int num_trial_;
|
||||||
// record all mock actions
|
// record all mock actions
|
||||||
std::map<MockKey, int> mock_map;
|
std::map<MockKey, int> mock_map_;
|
||||||
// used to generate all kinds of exceptions
|
// used to generate all kinds of exceptions
|
||||||
inline void Verify(const MockKey &key, const char *name) {
|
inline void Verify(const MockKey &key, const char *name) {
|
||||||
if (mock_map.count(key) != 0) {
|
if (mock_map_.count(key) != 0) {
|
||||||
num_trial += 1;
|
num_trial_ += 1;
|
||||||
// data processing frameworks runs on shared process
|
// data processing frameworks runs on shared process
|
||||||
_error("[%d]@@@Hit Mock Error:%s ", rank, name);
|
error_("[%d]@@@Hit Mock Error:%s ", rank, name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -40,9 +40,9 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
|
|||||||
const std::vector<EdgeType> &edge_in,
|
const std::vector<EdgeType> &edge_in,
|
||||||
size_t out_index)) {
|
size_t out_index)) {
|
||||||
RefLinkVector &links = tree_links;
|
RefLinkVector &links = tree_links;
|
||||||
if (links.size() == 0) return kSuccess;
|
if (links.Size() == 0) return kSuccess;
|
||||||
// number of links
|
// number of links
|
||||||
const int nlink = static_cast<int>(links.size());
|
const int nlink = static_cast<int>(links.Size());
|
||||||
// initialize the pointers
|
// initialize the pointers
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
links[i].ResetSize();
|
links[i].ResetSize();
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -22,18 +22,18 @@ namespace engine {
|
|||||||
/*! \brief implementation of fault tolerant all reduce engine */
|
/*! \brief implementation of fault tolerant all reduce engine */
|
||||||
class AllreduceRobust : public AllreduceBase {
|
class AllreduceRobust : public AllreduceBase {
|
||||||
public:
|
public:
|
||||||
AllreduceRobust(void);
|
AllreduceRobust();
|
||||||
virtual ~AllreduceRobust(void) {}
|
~AllreduceRobust() override = default;
|
||||||
// initialize the manager
|
// initialize the manager
|
||||||
virtual bool Init(int argc, char* argv[]);
|
bool Init(int argc, char* argv[]) override;
|
||||||
/*! \brief shutdown the engine */
|
/*! \brief shutdown the engine */
|
||||||
virtual bool Shutdown(void);
|
bool Shutdown() override;
|
||||||
/*!
|
/*!
|
||||||
* \brief set parameters to the engine
|
* \brief set parameters to the engine
|
||||||
* \param name parameter name
|
* \param name parameter name
|
||||||
* \param val parameter value
|
* \param val parameter value
|
||||||
*/
|
*/
|
||||||
virtual void SetParam(const char *name, const char *val);
|
void SetParam(const char *name, const char *val) override;
|
||||||
/*!
|
/*!
|
||||||
* \brief perform immutable local bootstrap cache insertion
|
* \brief perform immutable local bootstrap cache insertion
|
||||||
* \param key unique cache key
|
* \param key unique cache key
|
||||||
@ -67,13 +67,10 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
* \param _line caller line number used to generate unique cache key
|
* \param _line caller line number used to generate unique cache key
|
||||||
* \param _caller caller function name used to generate unique cache key
|
* \param _caller caller function name used to generate unique cache key
|
||||||
*/
|
*/
|
||||||
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
|
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
|
||||||
size_t slice_begin,
|
size_t slice_end, size_t size_prev_slice,
|
||||||
size_t slice_end,
|
const char *_file = _FILE, const int _line = _LINE,
|
||||||
size_t size_prev_slice,
|
const char *_caller = _CALLER) override;
|
||||||
const char* _file = _FILE,
|
|
||||||
const int _line = _LINE,
|
|
||||||
const char* _caller = _CALLER);
|
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf
|
* \brief perform in-place allreduce, on sendrecvbuf
|
||||||
* this function is NOT thread-safe
|
* this function is NOT thread-safe
|
||||||
@ -90,15 +87,11 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
* \param _line caller line number used to generate unique cache key
|
* \param _line caller line number used to generate unique cache key
|
||||||
* \param _caller caller function name used to generate unique cache key
|
* \param _caller caller function name used to generate unique cache key
|
||||||
*/
|
*/
|
||||||
virtual void Allreduce(void *sendrecvbuf_,
|
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
|
||||||
size_t type_nbytes,
|
ReduceFunction reducer, PreprocFunction prepare_fun = nullptr,
|
||||||
size_t count,
|
void *prepare_arg = nullptr, const char *_file = _FILE,
|
||||||
ReduceFunction reducer,
|
const int _line = _LINE,
|
||||||
PreprocFunction prepare_fun = NULL,
|
const char *_caller = _CALLER) override;
|
||||||
void *prepare_arg = NULL,
|
|
||||||
const char* _file = _FILE,
|
|
||||||
const int _line = _LINE,
|
|
||||||
const char* _caller = _CALLER);
|
|
||||||
/*!
|
/*!
|
||||||
* \brief broadcast data from root to all nodes
|
* \brief broadcast data from root to all nodes
|
||||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
@ -108,10 +101,9 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
* \param _line caller line number used to generate unique cache key
|
* \param _line caller line number used to generate unique cache key
|
||||||
* \param _caller caller function name used to generate unique cache key
|
* \param _caller caller function name used to generate unique cache key
|
||||||
*/
|
*/
|
||||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||||
const char* _file = _FILE,
|
const char *_file = _FILE, const int _line = _LINE,
|
||||||
const int _line = _LINE,
|
const char *_caller = _CALLER) override;
|
||||||
const char* _caller = _CALLER);
|
|
||||||
/*!
|
/*!
|
||||||
* \brief load latest check point
|
* \brief load latest check point
|
||||||
* \param global_model pointer to the globally shared model/state
|
* \param global_model pointer to the globally shared model/state
|
||||||
@ -134,8 +126,8 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
*
|
*
|
||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual int LoadCheckPoint(Serializable *global_model,
|
int LoadCheckPoint(Serializable *global_model,
|
||||||
Serializable *local_model = NULL);
|
Serializable *local_model = nullptr) override;
|
||||||
/*!
|
/*!
|
||||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||||
* every time we call check point, there is a version number which will increase by one
|
* every time we call check point, there is a version number which will increase by one
|
||||||
@ -152,9 +144,9 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
*
|
*
|
||||||
* \sa LoadCheckPoint, VersionNumber
|
* \sa LoadCheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual void CheckPoint(const Serializable *global_model,
|
void CheckPoint(const Serializable *global_model,
|
||||||
const Serializable *local_model = NULL) {
|
const Serializable *local_model = nullptr) override {
|
||||||
this->CheckPoint_(global_model, local_model, false);
|
this->CheckPointImpl(global_model, local_model, false);
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief This function can be used to replace CheckPoint for global_model only,
|
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||||
@ -176,18 +168,20 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
* is the same in all nodes
|
* is the same in all nodes
|
||||||
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual void LazyCheckPoint(const Serializable *global_model) {
|
void LazyCheckPoint(const Serializable *global_model) override {
|
||||||
this->CheckPoint_(global_model, NULL, true);
|
this->CheckPointImpl(global_model, nullptr, true);
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief explicitly re-init everything before calling LoadCheckPoint
|
* \brief explicitly re-init everything before calling LoadCheckPoint
|
||||||
* call this function when IEngine throw an exception out,
|
* call this function when IEngine throw an exception out,
|
||||||
* this function is only used for test purpose
|
* this function is only used for test purpose
|
||||||
*/
|
*/
|
||||||
virtual void InitAfterException(void) {
|
void InitAfterException() override {
|
||||||
// simple way, shutdown all links
|
// simple way, shutdown all links
|
||||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
for (auto& link : all_links) {
|
||||||
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
|
if (link.sock.BadSocket()) {
|
||||||
|
link.sock.Close();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ReConnectLinks("recover");
|
ReConnectLinks("recover");
|
||||||
}
|
}
|
||||||
@ -245,69 +239,69 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
// there are nodes request load cache
|
// there are nodes request load cache
|
||||||
static const int kLoadBootstrapCache = 16;
|
static const int kLoadBootstrapCache = 16;
|
||||||
// constructor
|
// constructor
|
||||||
ActionSummary(void) {}
|
ActionSummary() = default;
|
||||||
// constructor of action
|
// constructor of action
|
||||||
explicit ActionSummary(int seqno_flag, int cache_flag = 0,
|
explicit ActionSummary(int seqno_flag, int cache_flag = 0,
|
||||||
u_int32_t minseqno = kSpecialOp, u_int32_t maxseqno = kSpecialOp) {
|
u_int32_t minseqno = kSpecialOp, u_int32_t maxseqno = kSpecialOp) {
|
||||||
seqcode = (minseqno << 5) | seqno_flag;
|
seqcode_ = (minseqno << 5) | seqno_flag;
|
||||||
maxseqcode = (maxseqno << 5) | cache_flag;
|
maxseqcode_ = (maxseqno << 5) | cache_flag;
|
||||||
}
|
}
|
||||||
// minimum number of all operations by default
|
// minimum number of all operations by default
|
||||||
// maximum number of all cache operations otherwise
|
// maximum number of all cache operations otherwise
|
||||||
inline u_int32_t seqno(SeqType t = SeqType::kSeq) const {
|
inline u_int32_t Seqno(SeqType t = SeqType::kSeq) const {
|
||||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||||
return code >> 5;
|
return code >> 5;
|
||||||
}
|
}
|
||||||
// whether the operation set contains a load_check
|
// whether the operation set contains a load_check
|
||||||
inline bool load_check(SeqType t = SeqType::kSeq) const {
|
inline bool LoadCheck(SeqType t = SeqType::kSeq) const {
|
||||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||||
return (code & kLoadCheck) != 0;
|
return (code & kLoadCheck) != 0;
|
||||||
}
|
}
|
||||||
// whether the operation set contains a load_cache
|
// whether the operation set contains a load_cache
|
||||||
inline bool load_cache(SeqType t = SeqType::kSeq) const {
|
inline bool LoadCache(SeqType t = SeqType::kSeq) const {
|
||||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||||
return (code & kLoadBootstrapCache) != 0;
|
return (code & kLoadBootstrapCache) != 0;
|
||||||
}
|
}
|
||||||
// whether the operation set contains a check point
|
// whether the operation set contains a check point
|
||||||
inline bool check_point(SeqType t = SeqType::kSeq) const {
|
inline bool CheckPoint(SeqType t = SeqType::kSeq) const {
|
||||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||||
return (code & kCheckPoint) != 0;
|
return (code & kCheckPoint) != 0;
|
||||||
}
|
}
|
||||||
// whether the operation set contains a check ack
|
// whether the operation set contains a check ack
|
||||||
inline bool check_ack(SeqType t = SeqType::kSeq) const {
|
inline bool CheckAck(SeqType t = SeqType::kSeq) const {
|
||||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||||
return (code & kCheckAck) != 0;
|
return (code & kCheckAck) != 0;
|
||||||
}
|
}
|
||||||
// whether the operation set contains different sequence number
|
// whether the operation set contains different sequence number
|
||||||
inline bool diff_seq() const {
|
inline bool DiffSeq() const {
|
||||||
return (seqcode & kDiffSeq) != 0;
|
return (seqcode_ & kDiffSeq) != 0;
|
||||||
}
|
}
|
||||||
// returns the operation flag of the result
|
// returns the operation flag of the result
|
||||||
inline int flag(SeqType t = SeqType::kSeq) const {
|
inline int Flag(SeqType t = SeqType::kSeq) const {
|
||||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||||
return code & 31;
|
return code & 31;
|
||||||
}
|
}
|
||||||
// print flags in user friendly way
|
// print flags in user friendly way
|
||||||
inline void print_flags(int rank, std::string prefix ) {
|
inline void PrintFlags(int rank, std::string prefix ) {
|
||||||
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n",
|
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n", rank,
|
||||||
rank, prefix.c_str(),
|
prefix.c_str(), Seqno(), CheckPoint(), CheckAck(),
|
||||||
seqno(), check_point(), check_ack(), load_cache(),
|
LoadCache(), DiffSeq(), Seqno(SeqType::kCache),
|
||||||
diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache));
|
LoadCache(SeqType::kCache));
|
||||||
}
|
}
|
||||||
// reducer for Allreduce, get the result ActionSummary from all nodes
|
// reducer for Allreduce, get the result ActionSummary from all nodes
|
||||||
inline static void Reducer(const void *src_, void *dst_,
|
inline static void Reducer(const void *src_, void *dst_,
|
||||||
int len, const MPI::Datatype &dtype) {
|
int len, const MPI::Datatype &dtype) {
|
||||||
const ActionSummary *src = (const ActionSummary*)src_;
|
const ActionSummary *src = static_cast<const ActionSummary*>(src_);
|
||||||
ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
|
ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
|
||||||
for (int i = 0; i < len; ++i) {
|
for (int i = 0; i < len; ++i) {
|
||||||
u_int32_t min_seqno = std::min(src[i].seqno(), dst[i].seqno());
|
u_int32_t min_seqno = std::min(src[i].Seqno(), dst[i].Seqno());
|
||||||
u_int32_t max_seqno = std::max(src[i].seqno(SeqType::kCache),
|
u_int32_t max_seqno = std::max(src[i].Seqno(SeqType::kCache),
|
||||||
dst[i].seqno(SeqType::kCache));
|
dst[i].Seqno(SeqType::kCache));
|
||||||
int action_flag = src[i].flag() | dst[i].flag();
|
int action_flag = src[i].Flag() | dst[i].Flag();
|
||||||
// if any node is not requester set to 0 otherwise 1
|
// if any node is not requester set to 0 otherwise 1
|
||||||
int role_flag = src[i].flag(SeqType::kCache) & dst[i].flag(SeqType::kCache);
|
int role_flag = src[i].Flag(SeqType::kCache) & dst[i].Flag(SeqType::kCache);
|
||||||
// if seqno is different in src and destination
|
// if seqno is different in src and destination
|
||||||
int seq_diff_flag = src[i].seqno() != dst[i].seqno() ? kDiffSeq : 0;
|
int seq_diff_flag = src[i].Seqno() != dst[i].Seqno() ? kDiffSeq : 0;
|
||||||
// apply or to both seq diff flag as well as cache seq diff flag
|
// apply or to both seq diff flag as well as cache seq diff flag
|
||||||
dst[i] = ActionSummary(action_flag | seq_diff_flag,
|
dst[i] = ActionSummary(action_flag | seq_diff_flag,
|
||||||
role_flag, min_seqno, max_seqno);
|
role_flag, min_seqno, max_seqno);
|
||||||
@ -316,19 +310,19 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
// internel sequence code min of rabit seqno
|
// internel sequence code min of rabit seqno
|
||||||
u_int32_t seqcode;
|
u_int32_t seqcode_;
|
||||||
// internal sequence code max of cache seqno
|
// internal sequence code max of cache seqno
|
||||||
u_int32_t maxseqcode;
|
u_int32_t maxseqcode_;
|
||||||
};
|
};
|
||||||
/*! \brief data structure to remember result of Bcast and Allreduce calls*/
|
/*! \brief data structure to remember result of Bcast and Allreduce calls*/
|
||||||
class ResultBuffer{
|
class ResultBuffer{
|
||||||
public:
|
public:
|
||||||
// constructor
|
// constructor
|
||||||
ResultBuffer(void) {
|
ResultBuffer() {
|
||||||
this->Clear();
|
this->Clear();
|
||||||
}
|
}
|
||||||
// clear the existing record
|
// clear the existing record
|
||||||
inline void Clear(void) {
|
inline void Clear() {
|
||||||
seqno_.clear(); size_.clear();
|
seqno_.clear(); size_.clear();
|
||||||
rptr_.clear(); rptr_.push_back(0);
|
rptr_.clear(); rptr_.push_back(0);
|
||||||
data_.clear();
|
data_.clear();
|
||||||
@ -358,12 +352,12 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
inline void* Query(int seqid, size_t *p_size) {
|
inline void* Query(int seqid, size_t *p_size) {
|
||||||
size_t idx = std::lower_bound(seqno_.begin(),
|
size_t idx = std::lower_bound(seqno_.begin(),
|
||||||
seqno_.end(), seqid) - seqno_.begin();
|
seqno_.end(), seqid) - seqno_.begin();
|
||||||
if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL;
|
if (idx == seqno_.size() || seqno_[idx] != seqid) return nullptr;
|
||||||
*p_size = size_[idx];
|
*p_size = size_[idx];
|
||||||
return BeginPtr(data_) + rptr_[idx];
|
return BeginPtr(data_) + rptr_[idx];
|
||||||
}
|
}
|
||||||
// drop last stored result
|
// drop last stored result
|
||||||
inline void DropLast(void) {
|
inline void DropLast() {
|
||||||
utils::Assert(seqno_.size() != 0, "there is nothing to be dropped");
|
utils::Assert(seqno_.size() != 0, "there is nothing to be dropped");
|
||||||
seqno_.pop_back();
|
seqno_.pop_back();
|
||||||
rptr_.pop_back();
|
rptr_.pop_back();
|
||||||
@ -371,7 +365,7 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
data_.resize(rptr_.back());
|
data_.resize(rptr_.back());
|
||||||
}
|
}
|
||||||
// the sequence number of last stored result
|
// the sequence number of last stored result
|
||||||
inline int LastSeqNo(void) const {
|
inline int LastSeqNo() const {
|
||||||
if (seqno_.size() == 0) return -1;
|
if (seqno_.size() == 0) return -1;
|
||||||
return seqno_.back();
|
return seqno_.back();
|
||||||
}
|
}
|
||||||
@ -407,9 +401,8 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
*
|
*
|
||||||
* \sa CheckPoint, LazyCheckPoint
|
* \sa CheckPoint, LazyCheckPoint
|
||||||
*/
|
*/
|
||||||
void CheckPoint_(const Serializable *global_model,
|
void CheckPointImpl(const Serializable *global_model,
|
||||||
const Serializable *local_model,
|
const Serializable *local_model, bool lazy_checkpt);
|
||||||
bool lazy_checkpt);
|
|
||||||
/*!
|
/*!
|
||||||
* \brief reset the all the existing links by sending Out-of-Band message marker
|
* \brief reset the all the existing links by sending Out-of-Band message marker
|
||||||
* after this function finishes, all the messages received and sent
|
* after this function finishes, all the messages received and sent
|
||||||
@ -423,7 +416,7 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
* when kSockError is returned, it simply means there are bad sockets in the links,
|
* when kSockError is returned, it simply means there are bad sockets in the links,
|
||||||
* and some link recovery proceduer is needed
|
* and some link recovery proceduer is needed
|
||||||
*/
|
*/
|
||||||
ReturnType TryResetLinks(void);
|
ReturnType TryResetLinks();
|
||||||
/*!
|
/*!
|
||||||
* \brief if err_type indicates an error
|
* \brief if err_type indicates an error
|
||||||
* recover links according to the error type reported
|
* recover links according to the error type reported
|
||||||
@ -619,29 +612,29 @@ o * the input state must exactly one saved state(local state of current node)
|
|||||||
size_t out_index));
|
size_t out_index));
|
||||||
//---- recovery data structure ----
|
//---- recovery data structure ----
|
||||||
// the round of result buffer, used to mode the result
|
// the round of result buffer, used to mode the result
|
||||||
int result_buffer_round;
|
int result_buffer_round_;
|
||||||
// result buffer of all reduce
|
// result buffer of all reduce
|
||||||
ResultBuffer resbuf;
|
ResultBuffer resbuf_;
|
||||||
// current cached allreduce/braodcast sequence number
|
// current cached allreduce/braodcast sequence number
|
||||||
int cur_cache_seq;
|
int cur_cache_seq_;
|
||||||
// result buffer of cached all reduce
|
// result buffer of cached all reduce
|
||||||
ResultBuffer cachebuf;
|
ResultBuffer cachebuf_;
|
||||||
// key of each cache entry
|
// key of each cache entry
|
||||||
ResultBuffer lookupbuf;
|
ResultBuffer lookupbuf_;
|
||||||
// last check point global model
|
// last check point global model
|
||||||
std::string global_checkpoint;
|
std::string global_checkpoint_;
|
||||||
// lazy checkpoint of global model
|
// lazy checkpoint of global model
|
||||||
const Serializable *global_lazycheck;
|
const Serializable *global_lazycheck_;
|
||||||
// number of replica for local state/model
|
// number of replica for local state/model
|
||||||
int num_local_replica;
|
int num_local_replica_;
|
||||||
// number of default local replica
|
// number of default local replica
|
||||||
int default_local_replica;
|
int default_local_replica_;
|
||||||
// flag to decide whether local model is used, -1: unknown, 0: no, 1:yes
|
// flag to decide whether local model is used, -1: unknown, 0: no, 1:yes
|
||||||
int use_local_model;
|
int use_local_model_;
|
||||||
// number of replica for global state/model
|
// number of replica for global state/model
|
||||||
int num_global_replica;
|
int num_global_replica_;
|
||||||
// number of times recovery happens
|
// number of times recovery happens
|
||||||
int recover_counter;
|
int recover_counter_;
|
||||||
// --- recovery data structure for local checkpoint
|
// --- recovery data structure for local checkpoint
|
||||||
// there is two version of the data structure,
|
// there is two version of the data structure,
|
||||||
// at one time one version is valid and another is used as temp memory
|
// at one time one version is valid and another is used as temp memory
|
||||||
@ -649,21 +642,21 @@ o * the input state must exactly one saved state(local state of current node)
|
|||||||
// local model is stored in CSR format(like a sparse matrices)
|
// local model is stored in CSR format(like a sparse matrices)
|
||||||
// local_model[rptr[0]:rptr[1]] stores the model of current node
|
// local_model[rptr[0]:rptr[1]] stores the model of current node
|
||||||
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops
|
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops
|
||||||
std::vector<size_t> local_rptr[2];
|
std::vector<size_t> local_rptr_[2];
|
||||||
// storage for local model replicas
|
// storage for local model replicas
|
||||||
std::string local_chkpt[2];
|
std::string local_chkpt_[2];
|
||||||
// version of local checkpoint can be 1 or 0
|
// version of local checkpoint can be 1 or 0
|
||||||
int local_chkpt_version;
|
int local_chkpt_version_;
|
||||||
// if checkpoint were loaded, used to distinguish results boostrap cache from seqno cache
|
// if checkpoint were loaded, used to distinguish results boostrap cache from seqno cache
|
||||||
bool checkpoint_loaded;
|
bool checkpoint_loaded_;
|
||||||
// sidecar executing timeout task
|
// sidecar executing timeout task
|
||||||
std::future<bool> rabit_timeout_task;
|
std::future<bool> rabit_timeout_task_;
|
||||||
// flag to shutdown rabit_timeout_task before timeout
|
// flag to shutdown rabit_timeout_task before timeout
|
||||||
std::atomic<bool> shutdown_timeout{false};
|
std::atomic<bool> shutdown_timeout_{false};
|
||||||
// error handler
|
// error handler
|
||||||
void (* _error)(const char *fmt, ...) = utils::Error;
|
void (* error_)(const char *fmt, ...) = utils::Error;
|
||||||
// assert handler
|
// assert handler
|
||||||
void (* _assert)(bool exp, const char *fmt, ...) = utils::Assert;
|
void (* assert_)(bool exp, const char *fmt, ...) = utils::Assert;
|
||||||
};
|
};
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|||||||
@ -33,12 +33,12 @@ struct FHelper<op::BitOR, DType> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template<typename OP>
|
template<typename OP>
|
||||||
void Allreduce_(void *sendrecvbuf_,
|
void Allreduce(void *sendrecvbuf_,
|
||||||
size_t count,
|
size_t count,
|
||||||
engine::mpi::DataType enum_dtype,
|
engine::mpi::DataType enum_dtype,
|
||||||
void (*prepare_fun)(void *arg),
|
void (*prepare_fun)(void *arg),
|
||||||
void *prepare_arg) {
|
void *prepare_arg) {
|
||||||
using namespace engine::mpi;
|
using namespace engine::mpi; // NOLINT
|
||||||
switch (enum_dtype) {
|
switch (enum_dtype) {
|
||||||
case kChar:
|
case kChar:
|
||||||
rabit::Allreduce<OP>
|
rabit::Allreduce<OP>
|
||||||
@ -89,28 +89,28 @@ void Allreduce(void *sendrecvbuf,
|
|||||||
engine::mpi::OpType enum_op,
|
engine::mpi::OpType enum_op,
|
||||||
void (*prepare_fun)(void *arg),
|
void (*prepare_fun)(void *arg),
|
||||||
void *prepare_arg) {
|
void *prepare_arg) {
|
||||||
using namespace engine::mpi;
|
using namespace engine::mpi; // NOLINT
|
||||||
switch (enum_op) {
|
switch (enum_op) {
|
||||||
case kMax:
|
case kMax:
|
||||||
Allreduce_<op::Max>
|
Allreduce<op::Max>
|
||||||
(sendrecvbuf,
|
(sendrecvbuf,
|
||||||
count, enum_dtype,
|
count, enum_dtype,
|
||||||
prepare_fun, prepare_arg);
|
prepare_fun, prepare_arg);
|
||||||
return;
|
return;
|
||||||
case kMin:
|
case kMin:
|
||||||
Allreduce_<op::Min>
|
Allreduce<op::Min>
|
||||||
(sendrecvbuf,
|
(sendrecvbuf,
|
||||||
count, enum_dtype,
|
count, enum_dtype,
|
||||||
prepare_fun, prepare_arg);
|
prepare_fun, prepare_arg);
|
||||||
return;
|
return;
|
||||||
case kSum:
|
case kSum:
|
||||||
Allreduce_<op::Sum>
|
Allreduce<op::Sum>
|
||||||
(sendrecvbuf,
|
(sendrecvbuf,
|
||||||
count, enum_dtype,
|
count, enum_dtype,
|
||||||
prepare_fun, prepare_arg);
|
prepare_fun, prepare_arg);
|
||||||
return;
|
return;
|
||||||
case kBitwiseOR:
|
case kBitwiseOR:
|
||||||
Allreduce_<op::BitOR>
|
Allreduce<op::BitOR>
|
||||||
(sendrecvbuf,
|
(sendrecvbuf,
|
||||||
count, enum_dtype,
|
count, enum_dtype,
|
||||||
prepare_fun, prepare_arg);
|
prepare_fun, prepare_arg);
|
||||||
@ -124,7 +124,7 @@ void Allgather(void *sendrecvbuf_,
|
|||||||
size_t size_node_slice,
|
size_t size_node_slice,
|
||||||
size_t size_prev_slice,
|
size_t size_prev_slice,
|
||||||
int enum_dtype) {
|
int enum_dtype) {
|
||||||
using namespace engine::mpi;
|
using namespace engine::mpi; // NOLINT
|
||||||
size_t type_size = 0;
|
size_t type_size = 0;
|
||||||
switch (enum_dtype) {
|
switch (enum_dtype) {
|
||||||
case kChar:
|
case kChar:
|
||||||
@ -184,7 +184,7 @@ struct ReadWrapper : public Serializable {
|
|||||||
std::string *p_str;
|
std::string *p_str;
|
||||||
explicit ReadWrapper(std::string *p_str)
|
explicit ReadWrapper(std::string *p_str)
|
||||||
: p_str(p_str) {}
|
: p_str(p_str) {}
|
||||||
virtual void Load(Stream *fi) {
|
void Load(Stream *fi) override {
|
||||||
uint64_t sz;
|
uint64_t sz;
|
||||||
utils::Assert(fi->Read(&sz, sizeof(sz)) != 0,
|
utils::Assert(fi->Read(&sz, sizeof(sz)) != 0,
|
||||||
"Read pickle string");
|
"Read pickle string");
|
||||||
@ -194,7 +194,7 @@ struct ReadWrapper : public Serializable {
|
|||||||
"Read pickle string");
|
"Read pickle string");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
virtual void Save(Stream *fo) const {
|
void Save(Stream *fo) const override {
|
||||||
utils::Error("not implemented");
|
utils::Error("not implemented");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -206,10 +206,10 @@ struct WriteWrapper : public Serializable {
|
|||||||
size_t length)
|
size_t length)
|
||||||
: data(data), length(length) {
|
: data(data), length(length) {
|
||||||
}
|
}
|
||||||
virtual void Load(Stream *fi) {
|
void Load(Stream *fi) override {
|
||||||
utils::Error("not implemented");
|
utils::Error("not implemented");
|
||||||
}
|
}
|
||||||
virtual void Save(Stream *fo) const {
|
void Save(Stream *fo) const override {
|
||||||
uint64_t sz = static_cast<uint16_t>(length);
|
uint64_t sz = static_cast<uint16_t>(length);
|
||||||
fo->Write(&sz, sizeof(sz));
|
fo->Write(&sz, sizeof(sz));
|
||||||
fo->Write(data, length * sizeof(char));
|
fo->Write(data, length * sizeof(char));
|
||||||
@ -298,8 +298,8 @@ RABIT_DLL int RabitLoadCheckPoint(char **out_global_model,
|
|||||||
ReadWrapper sl(&local_buffer);
|
ReadWrapper sl(&local_buffer);
|
||||||
int version;
|
int version;
|
||||||
|
|
||||||
if (out_local_model == NULL) {
|
if (out_local_model == nullptr) {
|
||||||
version = rabit::LoadCheckPoint(&sg, NULL);
|
version = rabit::LoadCheckPoint(&sg, nullptr);
|
||||||
*out_global_model = BeginPtr(global_buffer);
|
*out_global_model = BeginPtr(global_buffer);
|
||||||
*out_global_len = static_cast<rbt_ulong>(global_buffer.length());
|
*out_global_len = static_cast<rbt_ulong>(global_buffer.length());
|
||||||
} else {
|
} else {
|
||||||
@ -317,8 +317,8 @@ RABIT_DLL void RabitCheckPoint(const char *global_model, rbt_ulong global_len,
|
|||||||
using namespace rabit::c_api; // NOLINT(*)
|
using namespace rabit::c_api; // NOLINT(*)
|
||||||
WriteWrapper sg(global_model, global_len);
|
WriteWrapper sg(global_model, global_len);
|
||||||
WriteWrapper sl(local_model, local_len);
|
WriteWrapper sl(local_model, local_len);
|
||||||
if (local_model == NULL) {
|
if (local_model == nullptr) {
|
||||||
rabit::CheckPoint(&sg, NULL);
|
rabit::CheckPoint(&sg, nullptr);
|
||||||
} else {
|
} else {
|
||||||
rabit::CheckPoint(&sg, &sl);
|
rabit::CheckPoint(&sg, &sl);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -7,18 +7,19 @@
|
|||||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||||
*/
|
*/
|
||||||
#include <rabit/base.h>
|
#include <rabit/base.h>
|
||||||
|
#include <dmlc/thread_local.h>
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "rabit/internal/engine.h"
|
#include "rabit/internal/engine.h"
|
||||||
#include "allreduce_base.h"
|
#include "allreduce_base.h"
|
||||||
#include "allreduce_robust.h"
|
#include "allreduce_robust.h"
|
||||||
#include "rabit/internal/thread_local.h"
|
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
// singleton sync manager
|
// singleton sync manager
|
||||||
#ifndef RABIT_USE_BASE
|
#ifndef RABIT_USE_BASE
|
||||||
#ifndef RABIT_USE_MOCK
|
#ifndef RABIT_USE_MOCK
|
||||||
typedef AllreduceRobust Manager;
|
using Manager = AllreduceRobust;
|
||||||
#else
|
#else
|
||||||
typedef AllreduceMock Manager;
|
typedef AllreduceMock Manager;
|
||||||
#endif // RABIT_USE_MOCK
|
#endif // RABIT_USE_MOCK
|
||||||
@ -31,13 +32,13 @@ struct ThreadLocalEntry {
|
|||||||
/*! \brief stores the current engine */
|
/*! \brief stores the current engine */
|
||||||
std::unique_ptr<Manager> engine;
|
std::unique_ptr<Manager> engine;
|
||||||
/*! \brief whether init has been called */
|
/*! \brief whether init has been called */
|
||||||
bool initialized;
|
bool initialized{false};
|
||||||
/*! \brief constructor */
|
/*! \brief constructor */
|
||||||
ThreadLocalEntry() : initialized(false) {}
|
ThreadLocalEntry() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
// define the threadlocal store.
|
// define the threadlocal store.
|
||||||
typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal;
|
using EngineThreadLocal = dmlc::ThreadLocalStore<ThreadLocalEntry>;
|
||||||
|
|
||||||
/*! \brief intiialize the synchronization module */
|
/*! \brief intiialize the synchronization module */
|
||||||
bool Init(int argc, char *argv[]) {
|
bool Init(int argc, char *argv[]) {
|
||||||
@ -95,7 +96,7 @@ void Allgather(void *sendrecvbuf_, size_t total_size,
|
|||||||
|
|
||||||
|
|
||||||
// perform in-place allreduce, on sendrecvbuf
|
// perform in-place allreduce, on sendrecvbuf
|
||||||
void Allreduce_(void *sendrecvbuf,
|
void Allreduce_(void *sendrecvbuf, // NOLINT
|
||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
IEngine::ReduceFunction red,
|
IEngine::ReduceFunction red,
|
||||||
@ -111,18 +112,15 @@ void Allreduce_(void *sendrecvbuf,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// code for reduce handle
|
// code for reduce handle
|
||||||
ReduceHandle::ReduceHandle(void)
|
ReduceHandle::ReduceHandle() = default;
|
||||||
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
|
ReduceHandle::~ReduceHandle() = default;
|
||||||
}
|
|
||||||
|
|
||||||
ReduceHandle::~ReduceHandle(void) {}
|
|
||||||
|
|
||||||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||||
return static_cast<int>(dtype.type_size);
|
return static_cast<int>(dtype.type_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
||||||
utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice");
|
utils::Assert(redfunc_ == nullptr, "cannot initialize reduce handle twice");
|
||||||
redfunc_ = redfunc;
|
redfunc_ = redfunc;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -133,7 +131,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf,
|
|||||||
const char* _file,
|
const char* _file,
|
||||||
const int _line,
|
const int _line,
|
||||||
const char* _caller) {
|
const char* _caller) {
|
||||||
utils::Assert(redfunc_ != NULL, "must intialize handle to call AllReduce");
|
utils::Assert(redfunc_ != nullptr, "must intialize handle to call AllReduce");
|
||||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
|
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
|
||||||
redfunc_, prepare_fun, prepare_arg,
|
redfunc_, prepare_fun, prepare_arg,
|
||||||
_file, _line, _caller);
|
_file, _line, _caller);
|
||||||
|
|||||||
@ -16,78 +16,68 @@ namespace engine {
|
|||||||
/*! \brief EmptyEngine */
|
/*! \brief EmptyEngine */
|
||||||
class EmptyEngine : public IEngine {
|
class EmptyEngine : public IEngine {
|
||||||
public:
|
public:
|
||||||
EmptyEngine(void) {
|
EmptyEngine() {
|
||||||
version_number = 0;
|
version_number_ = 0;
|
||||||
}
|
}
|
||||||
virtual void Allgather(void *sendrecvbuf_,
|
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
|
||||||
size_t total_size,
|
size_t slice_end, size_t size_prev_slice, const char *_file,
|
||||||
size_t slice_begin,
|
const int _line, const char *_caller) override {
|
||||||
size_t slice_end,
|
|
||||||
size_t size_prev_slice,
|
|
||||||
const char* _file,
|
|
||||||
const int _line,
|
|
||||||
const char* _caller) {
|
|
||||||
utils::Error("EmptyEngine:: Allgather is not supported");
|
utils::Error("EmptyEngine:: Allgather is not supported");
|
||||||
}
|
}
|
||||||
virtual int GetRingPrevRank(void) const {
|
int GetRingPrevRank() const override {
|
||||||
utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
|
utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
virtual void Allreduce(void *sendrecvbuf_,
|
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
|
||||||
size_t type_nbytes,
|
ReduceFunction reducer, PreprocFunction prepare_fun,
|
||||||
size_t count,
|
void *prepare_arg, const char *_file, const int _line,
|
||||||
ReduceFunction reducer,
|
const char *_caller) override {
|
||||||
PreprocFunction prepare_fun,
|
|
||||||
void *prepare_arg,
|
|
||||||
const char* _file,
|
|
||||||
const int _line,
|
|
||||||
const char* _caller) {
|
|
||||||
utils::Error("EmptyEngine:: Allreduce is not supported,"\
|
utils::Error("EmptyEngine:: Allreduce is not supported,"\
|
||||||
"use Allreduce_ instead");
|
"use Allreduce_ instead");
|
||||||
}
|
}
|
||||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root,
|
void Broadcast(void *sendrecvbuf_, size_t size, int root,
|
||||||
const char* _file, const int _line, const char* _caller) {
|
const char* _file, const int _line, const char* _caller) override {
|
||||||
}
|
}
|
||||||
virtual void InitAfterException(void) {
|
void InitAfterException() override {
|
||||||
utils::Error("EmptyEngine is not fault tolerant");
|
utils::Error("EmptyEngine is not fault tolerant");
|
||||||
}
|
}
|
||||||
virtual int LoadCheckPoint(Serializable *global_model,
|
int LoadCheckPoint(Serializable *global_model,
|
||||||
Serializable *local_model = NULL) {
|
Serializable *local_model = nullptr) override {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
virtual void CheckPoint(const Serializable *global_model,
|
void CheckPoint(const Serializable *global_model,
|
||||||
const Serializable *local_model = NULL) {
|
const Serializable *local_model = nullptr) override {
|
||||||
version_number += 1;
|
version_number_ += 1;
|
||||||
}
|
}
|
||||||
virtual void LazyCheckPoint(const Serializable *global_model) {
|
void LazyCheckPoint(const Serializable *global_model) override {
|
||||||
version_number += 1;
|
version_number_ += 1;
|
||||||
}
|
}
|
||||||
virtual int VersionNumber(void) const {
|
int VersionNumber() const override {
|
||||||
return version_number;
|
return version_number_;
|
||||||
}
|
}
|
||||||
/*! \brief get rank of current node */
|
/*! \brief get rank of current node */
|
||||||
virtual int GetRank(void) const {
|
int GetRank() const override {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
/*! \brief get total number of */
|
/*! \brief get total number of */
|
||||||
virtual int GetWorldSize(void) const {
|
int GetWorldSize() const override {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
/*! \brief whether it is distributed */
|
/*! \brief whether it is distributed */
|
||||||
virtual bool IsDistributed(void) const {
|
bool IsDistributed() const override {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
/*! \brief get the host name of current node */
|
/*! \brief get the host name of current node */
|
||||||
virtual std::string GetHost(void) const {
|
std::string GetHost() const override {
|
||||||
return std::string("");
|
return std::string("");
|
||||||
}
|
}
|
||||||
virtual void TrackerPrint(const std::string &msg) {
|
void TrackerPrint(const std::string &msg) override {
|
||||||
// simply print information into the tracker
|
// simply print information into the tracker
|
||||||
utils::Printf("%s", msg.c_str());
|
utils::Printf("%s", msg.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int version_number;
|
int version_number_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// singleton sync manager
|
// singleton sync manager
|
||||||
@ -98,12 +88,12 @@ bool Init(int argc, char *argv[]) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
/*! \brief finalize syncrhonization module */
|
/*! \brief finalize syncrhonization module */
|
||||||
bool Finalize(void) {
|
bool Finalize() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief singleton method to get engine */
|
/*! \brief singleton method to get engine */
|
||||||
IEngine *GetEngine(void) {
|
IEngine *GetEngine() {
|
||||||
return &manager;
|
return &manager;
|
||||||
}
|
}
|
||||||
// perform in-place allreduce, on sendrecvbuf
|
// perform in-place allreduce, on sendrecvbuf
|
||||||
@ -118,13 +108,12 @@ void Allreduce_(void *sendrecvbuf,
|
|||||||
const char* _file,
|
const char* _file,
|
||||||
const int _line,
|
const int _line,
|
||||||
const char* _caller) {
|
const char* _caller) {
|
||||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
// code for reduce handle
|
// code for reduce handle
|
||||||
ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
|
ReduceHandle::ReduceHandle() = default;
|
||||||
}
|
ReduceHandle::~ReduceHandle() = default;
|
||||||
ReduceHandle::~ReduceHandle(void) {}
|
|
||||||
|
|
||||||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||||
return 0;
|
return 0;
|
||||||
@ -137,7 +126,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf,
|
|||||||
const char* _file,
|
const char* _file,
|
||||||
const int _line,
|
const int _line,
|
||||||
const char* _caller) {
|
const char* _caller) {
|
||||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
|
||||||
}
|
}
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|||||||
@ -12,4 +12,3 @@
|
|||||||
#include <rabit/base.h>
|
#include <rabit/base.h>
|
||||||
#include "allreduce_mock.h"
|
#include "allreduce_mock.h"
|
||||||
#include "engine.cc"
|
#include "engine.cc"
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user