Merge branch 'master' of https://github.com/tqchen/allreduce
commit 34c8253ad6

Makefile: 5 lines changed
@@ -2,7 +2,8 @@ export CC = gcc
 export CXX = g++
 export MPICXX = mpicxx
 export LDFLAGS= -Llib
-export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -Iinclude
+export WARNFLAGS= -Wall -Wextra -Wno-unused-parameter -Wno-unknown-pragmas -pedantic
+export CFLAGS = -O3 -msse2 -fPIC -Iinclude $(WARNFLAGS)

 # build path
 BPATH=.
@@ -15,7 +16,7 @@ ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a lib/librabit_mock.a
 HEADERS=src/*.h include/*.h include/rabit/*.h
 .PHONY: clean all install mpi python

-all: lib/librabit.a lib/librabit_mock.a $(SLIB)
+all: lib/librabit.a lib/librabit_mock.a wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so
 mpi: lib/librabit_mpi.a wrapper/librabit_wrapper_mpi.so
 python: wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so

@@ -203,6 +203,27 @@ inline int LoadCheckPoint(ISerializable *global_model,
  */
 inline void CheckPoint(const ISerializable *global_model,
                        const ISerializable *local_model = NULL);
+/*!
+ * \brief This function can be used to replace CheckPoint for global_model only,
+ *   when certain condition is met(see detailed expplaination).
+ *
+ * This is a "lazy" checkpoint such that only the pointer to global_model is
+ * remembered and no memory copy is taken. To use this function, the user MUST ensure that:
+ * The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
+ * In another words, global_model model can be changed only between last call of
+ * Allreduce/Broadcast and LazyCheckPoint in current version
+ *
+ * For example, suppose the calling sequence is:
+ * LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
+ *
+ * If user can only changes global_model in code3, then LazyCheckPoint can be used to
+ * improve efficiency of the program.
+ * \param global_model pointer to the globally shared model/state
+ *   when calling this function, the caller need to gauranttees that global_model
+ *   is the same in all nodes
+ * \sa LoadCheckPoint, CheckPoint, VersionNumber
+ */
+inline void LazyCheckPoint(const ISerializable *global_model);
 /*!
  * \return version number of current stored model,
  *         which means how many calls to CheckPoint we made so far
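The comment block above fixes the contract for LazyCheckPoint: the checkpointed model may only be mutated between the last Allreduce/Broadcast of a version and the next LazyCheckPoint. As a reading aid only, here is a minimal sketch of that calling sequence. It is not part of the commit: MyModel and its Update method are hypothetical placeholders assumed to implement ISerializable, the include path is assumed, and the Allreduce/Broadcast signatures are assumed to be the templated helpers exposed by this public header.

// Illustrative sketch, not part of this commit.
// MyModel is a hypothetical type assumed to implement rabit::ISerializable;
// Update is a placeholder for whatever mutates the model.
#include <cstddef>
#include <rabit.h>   // public header of this project (path assumed)

void RunOneVersion(MyModel *model, float *grad, std::size_t n) {
  rabit::LazyCheckPoint(model);               // only the pointer is remembered, no copy
  // code1: prepare gradients, *model must stay unchanged
  rabit::Allreduce<rabit::op::Sum>(grad, n);  // last Allreduce of this version
  // code2: *model still unchanged
  rabit::Broadcast(grad, n * sizeof(float), 0);
  // code3: the only window in which *model may change, i.e. after the last
  // Allreduce/Broadcast of this version and before the next LazyCheckPoint
  model->Update(grad, n);
}

CheckPoint would serialize the model immediately; LazyCheckPoint defers that copy until a recovering peer actually asks for it, which is what the TryLoadCheckPoint change later in this diff implements.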
@@ -114,6 +114,27 @@ class IEngine {
   */
  virtual void CheckPoint(const ISerializable *global_model,
                          const ISerializable *local_model = NULL) = 0;
+  /*!
+   * \brief This function can be used to replace CheckPoint for global_model only,
+   *   when certain condition is met(see detailed expplaination.
+   *
+   * This is a "lazy" checkpoint such that only the pointer to global_model is
+   * remembered and no memory copy is taken. To use this function, the user MUST ensure that:
+   * The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
+   * In another words, global_model model can be changed only between last call of
+   * Allreduce/Broadcast and LazyCheckPoint in current version
+   *
+   * For example, suppose the calling sequence is:
+   * LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
+   *
+   * If user can only changes global_model in code3, then LazyCheckPoint can be used to
+   * improve efficiency of the program.
+   * \param global_model pointer to the globally shared model/state
+   *   when calling this function, the caller need to gauranttees that global_model
+   *   is the same in all nodes
+   * \sa LoadCheckPoint, CheckPoint, VersionNumber
+   */
+  virtual void LazyCheckPoint(const ISerializable *global_model) = 0;
  /*!
   * \return version number of current stored model,
   *         which means how many calls to CheckPoint we made so far
@@ -221,8 +242,10 @@ class ReduceHandle {
  static int TypeSize(const MPI::Datatype &dtype);

 protected:
-  // handle data field
+  // handle function field
  void *handle_;
+  // reduce function of the reducer
+  IEngine::ReduceFunction *redfunc_;
  // handle to the type field
  void *htype_;
  // the created type in 4 bytes
@@ -183,6 +183,10 @@ inline void CheckPoint(const ISerializable *global_model,
                        const ISerializable *local_model) {
  engine::GetEngine()->CheckPoint(global_model, local_model);
 }
+// lazy checkpoint the model, only remember the pointer to global_model
+inline void LazyCheckPoint(const ISerializable *global_model) {
+  engine::GetEngine()->LazyCheckPoint(global_model);
+}
 // return the version number of currently stored model
 inline int VersionNumber(void) {
  return engine::GetEngine()->VersionNumber();
@@ -197,7 +201,7 @@ inline void ReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Data
  const char *psrc = reinterpret_cast<const char*>(src_);
  char *pdst = reinterpret_cast<char*>(dst_);
  DType tdst, tsrc;
-  for (size_t i = 0; i < len_; ++i) {
+  for (int i = 0; i < len_; ++i) {
    // use memcpy to avoid alignment issue
    std::memcpy(&tdst, pdst + i * kUnit, sizeof(tdst));
    std::memcpy(&tsrc, psrc + i * kUnit, sizeof(tsrc));
@@ -27,6 +27,7 @@ AllreduceBase::AllreduceBase(void) {
  hadoop_mode = 0;
  version_number = 0;
  task_id = "NULL";
+  err_link = NULL;
  this->SetParam("rabit_reduce_buffer", "256MB");
 }

@@ -213,7 +214,7 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
    } else {
      if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
    }
  }
  int ngood = static_cast<int>(good_link.size());
  Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
         "ReConnectLink failure 5");
@@ -365,7 +366,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
        selecter.WatchException(links[i].sock);
        finished = false;
      }
-      if (size_up_out != total_size) {
+      if (size_up_out != total_size && size_up_out < size_up_reduce) {
        selecter.WatchWrite(links[i].sock);
      }
    } else {
@@ -373,8 +374,10 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
        selecter.WatchRead(links[i].sock);
      }
      // size_write <= size_read
-      if (links[i].size_write != total_size) {
-        selecter.WatchWrite(links[i].sock);
+      if (links[i].size_write != total_size){
+        if (links[i].size_write < size_down_in) {
+          selecter.WatchWrite(links[i].sock);
+        }
        // only watch for exception in live channels
        selecter.WatchException(links[i].sock);
        finished = false;
@@ -388,12 +391,17 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
    // exception handling
    for (int i = 0; i < nlink; ++i) {
      // recive OOB message from some link
-      if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
+      if (selecter.CheckExcept(links[i].sock)) {
+        return ReportError(&links[i], kGetExcept);
+      }
    }
    // read data from childs
    for (int i = 0; i < nlink; ++i) {
      if (i != parent_index && selecter.CheckRead(links[i].sock)) {
-        if (!links[i].ReadToRingBuffer(size_up_out)) return kSockError;
+        ReturnType ret = links[i].ReadToRingBuffer(size_up_out);
+        if (ret != kSuccess) {
+          return ReportError(&links[i], ret);
+        }
      }
    }
    // this node have childs, peform reduce
@@ -433,13 +441,16 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
      }
    }
    if (parent_index != -1) {
      // pass message up to parent, can pass data that are already been reduced
-      if (selecter.CheckWrite(links[parent_index].sock)) {
+      if (size_up_out < size_up_reduce) {
        ssize_t len = links[parent_index].sock.
            Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out);
        if (len != -1) {
          size_up_out += static_cast<size_t>(len);
        } else {
-          if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
+          ReturnType ret = Errno2Return(errno);
+          if (ret != kSuccess) {
+            return ReportError(&links[parent_index], ret);
+          }
        }
      }
      // read data from parent
@@ -448,14 +459,18 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
        ssize_t len = links[parent_index].sock.
            Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
        if (len == 0) {
-          links[parent_index].sock.Close(); return kSockError;
+          links[parent_index].sock.Close();
+          return ReportError(&links[parent_index], kRecvZeroLen);
        }
        if (len != -1) {
          size_down_in += static_cast<size_t>(len);
          utils::Assert(size_down_in <= size_up_out,
                        "Allreduce: boundary error");
        } else {
-          if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
+          ReturnType ret = Errno2Return(errno);
+          if (ret != kSuccess) {
+            return ReportError(&links[parent_index], ret);
+          }
        }
      }
    } else {
@@ -464,9 +479,10 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
      }
      // can pass message down to childs
      for (int i = 0; i < nlink; ++i) {
-        if (i != parent_index && selecter.CheckWrite(links[i].sock)) {
-          if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) {
-            return kSockError;
+        if (i != parent_index && links[i].size_write < size_down_in) {
+          ReturnType ret = links[i].WriteFromArray(sendrecvbuf, size_down_in);
+          if (ret != kSuccess) {
+            return ReportError(&links[i], ret);
          }
        }
      }
@@ -516,7 +532,10 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
        selecter.WatchRead(links[i].sock); finished = false;
      }
      if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
-        selecter.WatchWrite(links[i].sock); finished = false;
+        if (links[i].size_write < size_in) {
+          selecter.WatchWrite(links[i].sock);
+        }
+        finished = false;
      }
      selecter.WatchException(links[i].sock);
    }
@@ -527,14 +546,17 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
    // exception handling
    for (int i = 0; i < nlink; ++i) {
      // recive OOB message from some link
-      if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
+      if (selecter.CheckExcept(links[i].sock)) {
+        return ReportError(&links[i], kGetExcept);
+      }
    }
    if (in_link == -2) {
      // probe in-link
      for (int i = 0; i < nlink; ++i) {
        if (selecter.CheckRead(links[i].sock)) {
-          if (!links[i].ReadToArray(sendrecvbuf_, total_size)) {
-            return kSockError;
+          ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size);
+          if (ret != kSuccess) {
+            return ReportError(&links[i], ret);
          }
          size_in = links[i].size_read;
          if (size_in != 0) {
@@ -545,16 +567,20 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
    } else {
      // read from in link
      if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) {
-        if (!links[in_link].ReadToArray(sendrecvbuf_, total_size)) {
-          return kSockError;
+        ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size);
+        if (ret != kSuccess) {
+          return ReportError(&links[in_link], ret);
        }
        size_in = links[in_link].size_read;
      }
    }
    // send data to all out-link
    for (int i = 0; i < nlink; ++i) {
-      if (i != in_link && selecter.CheckWrite(links[i].sock)) {
-        if (!links[i].WriteFromArray(sendrecvbuf_, size_in)) return kSockError;
+      if (i != in_link && links[i].size_write < size_in) {
+        ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, size_in);
+        if (ret != kSuccess) {
+          return ReportError(&links[i], ret);
+        }
      }
    }
  }
@@ -146,6 +146,29 @@ class AllreduceBase : public IEngine {
                          const ISerializable *local_model = NULL) {
    version_number += 1;
  }
+  /*!
+   * \brief This function can be used to replace CheckPoint for global_model only,
+   *   when certain condition is met(see detailed expplaination).
+   *
+   * This is a "lazy" checkpoint such that only the pointer to global_model is
+   * remembered and no memory copy is taken. To use this function, the user MUST ensure that:
+   * The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
+   * In another words, global_model model can be changed only between last call of
+   * Allreduce/Broadcast and LazyCheckPoint in current version
+   *
+   * For example, suppose the calling sequence is:
+   * LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
+   *
+   * If user can only changes global_model in code3, then LazyCheckPoint can be used to
+   * improve efficiency of the program.
+   * \param global_model pointer to the globally shared model/state
+   *   when calling this function, the caller need to gauranttees that global_model
+   *   is the same in all nodes
+   * \sa LoadCheckPoint, CheckPoint, VersionNumber
+   */
+  virtual void LazyCheckPoint(const ISerializable *global_model) {
+    version_number += 1;
+  }
  /*!
   * \return version number of current stored model,
   *         which means how many calls to CheckPoint we made so far
@@ -175,9 +198,13 @@ class AllreduceBase : public IEngine {

 protected:
  /*! \brief enumeration of possible returning results from Try functions */
-  enum ReturnType {
+  enum ReturnTypeEnum {
    /*! \brief execution is successful */
    kSuccess,
+    /*! \brief a link was reset by peer */
+    kConnReset,
+    /*! \brief received a zero length message */
+    kRecvZeroLen,
    /*! \brief a neighbor node go down, the connection is dropped */
    kSockError,
    /*!
@@ -186,6 +213,26 @@ class AllreduceBase : public IEngine {
     */
    kGetExcept
  };
+  /*! \brief struct return type to avoid implicit conversion to int/bool */
+  struct ReturnType {
+    /*! \brief internal return type */
+    ReturnTypeEnum value;
+    // constructor
+    ReturnType() {}
+    ReturnType(ReturnTypeEnum value) : value(value){}
+    inline bool operator==(const ReturnTypeEnum &v) const {
+      return value == v;
+    }
+    inline bool operator!=(const ReturnTypeEnum &v) const {
+      return value != v;
+    }
+  };
+  /*! \brief translate errno to return type */
+  inline static ReturnType Errno2Return(int errsv) {
+    if (errsv == EAGAIN || errsv == EWOULDBLOCK) return kSuccess;
+    if (errsv == ECONNRESET) return kConnReset;
+    return kSockError;
+  }
  // link record to a neighbor
  struct LinkRecord {
   public:
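The struct wrapper introduced above is what lets the rest of this patch replace bool returns without silent breakage: a bare enum converts to int/bool, so legacy call sites written as "if (!ReadToArray(...))" would still compile but invert their meaning. The self-contained snippet below is not from the repository; it reuses the same member names purely to show what the wrapper does and does not allow.

// Standalone illustration (not project code) of why the patch wraps the enum
// in a struct: the struct only supports == / != against the enum, so the old
// boolean call style stops compiling instead of becoming wrong.
enum ReturnTypeEnum { kSuccess, kConnReset, kRecvZeroLen, kSockError, kGetExcept };

struct ReturnType {
  ReturnTypeEnum value;
  ReturnType() {}
  ReturnType(ReturnTypeEnum v) : value(v) {}
  bool operator==(const ReturnTypeEnum &v) const { return value == v; }
  bool operator!=(const ReturnTypeEnum &v) const { return value != v; }
};

// stand-in for a socket helper that now reports a ReturnType
ReturnType FakeReadToArray() { return kSockError; }

int main() {
  ReturnType ret = FakeReadToArray();
  // if (!FakeReadToArray()) {}   // no longer compiles: no conversion to bool
  if (ret != kSuccess) {
    return 1;                     // callers must compare against the enum explicitly
  }
  return 0;
}

The same property is what allows Errno2Return and ReportError elsewhere in this diff to be chained as "return ReportError(&links[i], ret);" while keeping every success test explicit.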
@@ -202,7 +249,9 @@ class AllreduceBase : public IEngine {
    // buffer size, in bytes
    size_t buffer_size;
    // constructor
-    LinkRecord(void) {}
+    LinkRecord(void)
+        : buffer_head(NULL), buffer_size(0) {
+    }
    // initialize buffer
    inline void InitBuffer(size_t type_nbytes, size_t count,
                           size_t reduce_buffer_size) {
@@ -226,22 +275,23 @@ class AllreduceBase : public IEngine {
     *   position after protect_start
     * \param protect_start all data start from protect_start is still needed in buffer
     *   read shall not override this
-     * \return true if it is an successful read, false if there is some error happens, check errno
+     * \return the type of reading
     */
-    inline bool ReadToRingBuffer(size_t protect_start) {
+    inline ReturnType ReadToRingBuffer(size_t protect_start) {
+      utils::Assert(buffer_head != NULL, "ReadToRingBuffer: buffer not allocated");
      size_t ngap = size_read - protect_start;
      utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
      size_t offset = size_read % buffer_size;
      size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
-      if (nmax == 0) return true;
+      if (nmax == 0) return kSuccess;
      ssize_t len = sock.Recv(buffer_head + offset, nmax);
      // length equals 0, remote disconnected
      if (len == 0) {
-        sock.Close(); return false;
+        sock.Close(); return kRecvZeroLen;
      }
-      if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK;
+      if (len == -1) return Errno2Return(errno);
      size_read += static_cast<size_t>(len);
-      return true;
+      return kSuccess;
    }
    /*!
     * \brief read data into array,
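ReadToRingBuffer above receives into a circular buffer while refusing to overwrite bytes from protect_start onward that a slower writer still needs; the ngap/offset/nmax arithmetic is easiest to see with concrete numbers. The following standalone snippet is illustrative only (the numeric values are invented, not taken from the repository) and simply recomputes that window for one buffer state.

// Illustrative only: recomputes the receive window used by ReadToRingBuffer
// for one concrete state of the ring buffer.
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t buffer_size = 8;    // capacity of the ring buffer
  const std::size_t size_read = 10;     // total bytes received so far
  const std::size_t protect_start = 6;  // bytes before this are no longer needed
  std::size_t ngap = size_read - protect_start;   // 4 bytes must be preserved
  std::size_t offset = size_read % buffer_size;   // next write lands at slot 2
  // receive at most 4 bytes: limited by the protected region (8 - 4) and,
  // separately, by the distance to the physical end of the buffer (8 - 2)
  std::size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
  std::printf("ngap=%zu offset=%zu nmax=%zu\n", ngap, offset, nmax);
  return 0;
}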
@@ -250,17 +300,17 @@ class AllreduceBase : public IEngine {
     * \param max_size maximum size of array
     * \return true if it is an successful read, false if there is some error happens, check errno
     */
-    inline bool ReadToArray(void *recvbuf_, size_t max_size) {
-      if (max_size == size_read) return true;
+    inline ReturnType ReadToArray(void *recvbuf_, size_t max_size) {
+      if (max_size == size_read) return kSuccess;
      char *p = static_cast<char*>(recvbuf_);
      ssize_t len = sock.Recv(p + size_read, max_size - size_read);
      // length equals 0, remote disconnected
      if (len == 0) {
-        sock.Close(); return false;
+        sock.Close(); return kRecvZeroLen;
      }
-      if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK;
+      if (len == -1) return Errno2Return(errno);
      size_read += static_cast<size_t>(len);
-      return true;
+      return kSuccess;
    }
    /*!
     * \brief write data in array to sock
@@ -268,12 +318,12 @@ class AllreduceBase : public IEngine {
     * \param max_size maximum size of array
     * \return true if it is an successful write, false if there is some error happens, check errno
     */
-    inline bool WriteFromArray(const void *sendbuf_, size_t max_size) {
+    inline ReturnType WriteFromArray(const void *sendbuf_, size_t max_size) {
      const char *p = static_cast<const char*>(sendbuf_);
      ssize_t len = sock.Send(p + size_write, max_size - size_write);
-      if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK;
+      if (len == -1) return Errno2Return(errno);
      size_write += static_cast<size_t>(len);
-      return true;
+      return kSuccess;
    }

   private:
@@ -333,6 +383,14 @@ class AllreduceBase : public IEngine {
   * \sa ReturnType
   */
  ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
+  /*!
+   * \brief function used to report error when a link goes wrong
+   * \param link the pointer to the link who causes the error
+   * \param err the error type
+   */
+  inline ReturnType ReportError(LinkRecord *link, ReturnType err) {
+    err_link = link; return err;
+  }
  //---- data structure related to model ----
  // call sequence counter, records how many calls we made so far
  // from last call to CheckPoint, LoadCheckPoint
@@ -348,6 +406,8 @@ class AllreduceBase : public IEngine {
  int parent_rank;
  // sockets of all links this connects to
  std::vector<LinkRecord> all_links;
+  // used to record the link where things goes wrong
+  LinkRecord *err_link;
  // all the links in the reduction tree connection
  RefLinkVector tree_links;
  // pointer to links in the ring
@@ -97,7 +97,9 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
    // exception handling
    for (int i = 0; i < nlink; ++i) {
      // recive OOB message from some link
-      if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
+      if (selecter.CheckExcept(links[i].sock)) {
+        return ReportError(&links[i], kGetExcept);
+      }
    }
    if (stage == 0) {
      bool finished = true;
@@ -105,9 +107,8 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
      for (int i = 0; i < nlink; ++i) {
        if (i != parent_index) {
          if (selecter.CheckRead(links[i].sock)) {
-            if (!links[i].ReadToArray(&edge_in[i], sizeof(EdgeType))) {
-              return kSockError;
-            }
+            ReturnType ret = links[i].ReadToArray(&edge_in[i], sizeof(EdgeType));
+            if (ret != kSuccess) return ReportError(&links[i], ret);
          }
          if (links[i].size_read != sizeof(EdgeType)) finished = false;
        }
@@ -128,17 +129,15 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
    if (stage == 1) {
      const int pid = this->parent_index;
      utils::Assert(pid != -1, "MsgPassing invalid stage");
-      if (!links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType))) {
-        return kSockError;
-      }
+      ReturnType ret = links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType));
+      if (ret != kSuccess) return ReportError(&links[pid], ret);
      if (links[pid].size_write == sizeof(EdgeType)) stage = 2;
    }
    if (stage == 2) {
      const int pid = this->parent_index;
      utils::Assert(pid != -1, "MsgPassing invalid stage");
-      if (!links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType))) {
-        return kSockError;
-      }
+      ReturnType ret = links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType));
+      if (ret != kSuccess) return ReportError(&links[pid], ret);
      if (links[pid].size_read == sizeof(EdgeType)) {
        for (int i = 0; i < nlink; ++i) {
          if (i != pid) edge_out[i] = func(node_value, edge_in, i);
@@ -149,9 +148,8 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
    if (stage == 3) {
      for (int i = 0; i < nlink; ++i) {
        if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
-          if (!links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType))) {
-            return kSockError;
-          }
+          ReturnType ret = links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType));
+          if (ret != kSuccess) return ReportError(&links[i], ret);
        }
      }
    }
@@ -25,6 +25,9 @@ AllreduceRobust::AllreduceRobust(void) {
  seq_counter = 0;
  local_chkpt_version = 0;
  result_buffer_round = 1;
+  global_lazycheck = NULL;
+  use_local_model = -1;
+  recover_counter = 0;
 }
 void AllreduceRobust::Init(void) {
  AllreduceBase::Init();
@@ -154,9 +157,7 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
                                    ISerializable *local_model) {
  // skip action in single node
  if (world_size == 1) return 0;
-  if (local_model != NULL && num_local_replica == 0) {
-    num_local_replica = default_local_replica;
-  }
+  this->LocalModelCheck(local_model != NULL);
  if (num_local_replica == 0) {
    utils::Check(local_model == NULL,
                 "need to set rabit_local_replica larger than 1 to checkpoint local_model");
@@ -199,30 +200,50 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
  }
 }
 /*!
- * \brief checkpoint the model, meaning we finished a stage of execution
- * every time we call check point, there is a version number which will increase by one
+ * \brief internal consistency check function,
+ *  use check to ensure user always call CheckPoint/LoadCheckPoint
+ *  with or without local but not both, this function will set the approperiate settings
+ *  in the first call of LoadCheckPoint/CheckPoint
  *
+ * \param with_local whether the user calls CheckPoint with local model
+ */
+void AllreduceRobust::LocalModelCheck(bool with_local) {
+  if (use_local_model == -1) {
+    if (with_local) {
+      use_local_model = 1;
+      if (num_local_replica == 0) {
+        num_local_replica = default_local_replica;
+      }
+    } else {
+      use_local_model = 0;
+      num_local_replica = 0;
+    }
+  } else {
+    utils::Check(use_local_model == int(with_local),
+                 "Can only call Checkpoint/LoadCheckPoint always with"\
+                 "or without local_model, but not mixed case");
+  }
+}
+/*!
+ * \brief internal implementation of checkpoint, support both lazy and normal way
+ *
  * \param global_model pointer to the globally shared model/state
  *   when calling this function, the caller need to gauranttees that global_model
  *   is the same in all nodes
  * \param local_model pointer to local model, that is specific to current node/rank
  *   this can be NULL when no local state is needed
+ * \param lazy_checkpt whether the action is lazy checkpoint
  *
- * NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
- * bring replication cost in CheckPoint function. global_model do not need explicit replication.
- * So only CheckPoint with global_model if possible
- *
- * \sa LoadCheckPoint, VersionNumber
+ * \sa CheckPoint, LazyCheckPoint
  */
-void AllreduceRobust::CheckPoint(const ISerializable *global_model,
-                                 const ISerializable *local_model) {
+void AllreduceRobust::CheckPoint_(const ISerializable *global_model,
+                                  const ISerializable *local_model,
+                                  bool lazy_checkpt) {
  // never do check point in single machine mode
  if (world_size == 1) {
    version_number += 1; return;
  }
-  if (local_model != NULL && num_local_replica == 0) {
-    num_local_replica = default_local_replica;
-  }
+  this->LocalModelCheck(local_model != NULL);
  if (num_local_replica == 0) {
    utils::Check(local_model == NULL,
                 "need to set rabit_local_replica larger than 1 to checkpoint local_model");
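LocalModelCheck added above locks a program into one of two checkpointing styles at its first CheckPoint/LoadCheckPoint call. A hedged sketch of the style that always carries a local model follows; it is not code from this commit, MyGlobalModel and MyLocalModel are hypothetical types assumed to implement ISerializable, and the rabit namespace qualification is assumed from the rest of the project.

// Illustrative sketch only, not part of this commit.
void TrainWithLocalState(MyGlobalModel *global_model,
                         MyLocalModel *local_model,
                         int num_versions) {
  // every call below passes a local model, so the first call settles
  // use_local_model = 1 and enables the default number of local replicas
  int version = rabit::LoadCheckPoint(global_model, local_model);
  for (int v = version; v < num_versions; ++v) {
    // ... one version of work built from Allreduce/Broadcast calls ...
    rabit::CheckPoint(global_model, local_model);
  }
  // Calling CheckPoint(global_model) without the local model anywhere in this
  // program would now fail the "not mixed case" check added above.
}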
@@ -255,10 +276,15 @@ void AllreduceRobust::CheckPoint(const ISerializable *global_model,
  // increase version number
  version_number += 1;
  // save model
-  global_checkpoint.resize(0);
-  utils::MemoryBufferStream fs(&global_checkpoint);
-  fs.Write(&version_number, sizeof(version_number));
-  global_model->Save(fs);
+  if (lazy_checkpt) {
+    global_lazycheck = global_model;
+  } else {
+    global_checkpoint.resize(0);
+    utils::MemoryBufferStream fs(&global_checkpoint);
+    fs.Write(&version_number, sizeof(version_number));
+    global_model->Save(fs);
+    global_lazycheck = NULL;
+  }
  // reset result buffer
  resbuf.Clear(); seq_counter = 0;
  // execute check ack step, load happens here
@@ -396,6 +422,8 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
 */
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
  if (err_type == kSuccess) return true;
+  utils::Assert(err_link != NULL, "must know the error source");
+  recover_counter += 1;
  {
    // simple way, shutdown all links
    for (size_t i = 0; i < all_links.size(); ++i) {
@@ -407,7 +435,7 @@ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
  // this was old way
  // TryResetLinks still causes possible errors, so not use this one
  while (err_type != kSuccess) {
-    switch (err_type) {
+    switch (err_type.value) {
      case kGetExcept: err_type = TryResetLinks(); break;
      case kSockError: {
        TryResetLinks();
@@ -577,6 +605,9 @@ AllreduceRobust::TryRecoverData(RecoverType role,
    if (!req_data) return kSuccess;
  }
  utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
+  if (role == kPassData) {
+    links[recv_link].InitBuffer(1, size, reduce_buffer_size);
+  }
  for (int i = 0; i < nlink; ++i) {
    links[i].ResetSize();
  }
@@ -601,27 +632,33 @@ AllreduceRobust::TryRecoverData(RecoverType role,
    selecter.Select();
    // exception handling
    for (int i = 0; i < nlink; ++i) {
-      if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
+      if (selecter.CheckExcept(links[i].sock)) {
+        return ReportError(&links[i], kGetExcept);
+      }
    }
    if (role == kRequestData) {
      const int pid = recv_link;
      if (selecter.CheckRead(links[pid].sock)) {
-        if (!links[pid].ReadToArray(sendrecvbuf_, size)) return kSockError;
+        ReturnType ret = links[pid].ReadToArray(sendrecvbuf_, size);
+        if (ret != kSuccess) {
+          return ReportError(&links[pid], ret);
+        }
      }
      for (int i = 0; i < nlink; ++i) {
-        if (req_in[i] && links[i].size_write != links[pid].size_read &&
-            selecter.CheckWrite(links[i].sock)) {
-          if (!links[i].WriteFromArray(sendrecvbuf_, links[pid].size_read)) {
-            return kSockError;
+        if (req_in[i] && links[i].size_write != links[pid].size_read) {
+          ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, links[pid].size_read);
+          if (ret != kSuccess) {
+            return ReportError(&links[i], ret);
          }
        }
      }
    }
    if (role == kHaveData) {
      for (int i = 0; i < nlink; ++i) {
-        if (req_in[i] && selecter.CheckWrite(links[i].sock)) {
-          if (!links[i].WriteFromArray(sendrecvbuf_, size)) {
-            return kSockError;
+        if (req_in[i] && links[i].size_write != size) {
+          ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, size);
+          if (ret != kSuccess) {
+            return ReportError(&links[i], ret);
          }
        }
      }
@@ -635,11 +672,13 @@ AllreduceRobust::TryRecoverData(RecoverType role,
        if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
      }
      utils::Assert(min_write <= links[pid].size_read, "boundary check");
-      if (!links[pid].ReadToRingBuffer(min_write)) return kSockError;
+      ReturnType ret = links[pid].ReadToRingBuffer(min_write);
+      if (ret != kSuccess) {
+        return ReportError(&links[pid], ret);
+      }
    }
    for (int i = 0; i < nlink; ++i) {
-      if (req_in[i] && selecter.CheckWrite(links[i].sock) &&
-          links[pid].size_read != links[i].size_write) {
+      if (req_in[i] && links[pid].size_read != links[i].size_write) {
        size_t start = links[i].size_write % buffer_size;
        // send out data from ring buffer
        size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write);
@@ -647,7 +686,8 @@ AllreduceRobust::TryRecoverData(RecoverType role,
        if (len != -1) {
          links[i].size_write += len;
        } else {
-          if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
+          ReturnType ret = Errno2Return(errno);
+          if (ret != kSuccess) return ReportError(&links[i], ret);
        }
      }
    }
@@ -698,6 +738,14 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
    utils::Check(state == 1 || state == 2,
                 "LoadCheckPoint: too many nodes fails, cannot recover local state");
  }
+  // do call save model if the checkpoint was lazy
+  if (role == kHaveData && global_lazycheck != NULL) {
+    global_checkpoint.resize(0);
+    utils::MemoryBufferStream fs(&global_checkpoint);
+    fs.Write(&version_number, sizeof(version_number));
+    global_lazycheck->Save(fs);
+    global_lazycheck = NULL;
+  }
  // recover global checkpoint
  size_t size = this->global_checkpoint.length();
  int recv_link;
@@ -1098,27 +1146,28 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
    selecter.WatchException(next.sock);
    if (finished) break;
    selecter.Select();
-    if (selecter.CheckExcept(prev.sock)) return kGetExcept;
-    if (selecter.CheckExcept(next.sock)) return kGetExcept;
+    if (selecter.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept);
+    if (selecter.CheckExcept(next.sock)) return ReportError(&next, kGetExcept);
    if (read_ptr != read_end && selecter.CheckRead(prev.sock)) {
      ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr);
      if (len == 0) {
-        prev.sock.Close(); return kSockError;
+        prev.sock.Close(); return ReportError(&prev, kRecvZeroLen);
      }
      if (len != -1) {
        read_ptr += static_cast<size_t>(len);
      } else {
-        if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
+        ReturnType ret = Errno2Return(errno);
+        if (ret != kSuccess) return ReportError(&prev, ret);
      }
    }
-    if (write_ptr != write_end && write_ptr < read_ptr &&
-        selecter.CheckWrite(next.sock)) {
+    if (write_ptr != write_end && write_ptr < read_ptr) {
      size_t nsend = std::min(write_end - write_ptr, read_ptr - write_ptr);
      ssize_t len = next.sock.Send(buf + write_ptr, nsend);
      if (len != -1) {
        write_ptr += static_cast<size_t>(len);
      } else {
-        if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
+        ReturnType ret = Errno2Return(errno);
+        if (ret != kSuccess) return ReportError(&prev, ret);
      }
    }
  }
@@ -99,7 +99,32 @@ class AllreduceRobust : public AllreduceBase {
   * \sa LoadCheckPoint, VersionNumber
   */
  virtual void CheckPoint(const ISerializable *global_model,
-                          const ISerializable *local_model = NULL);
+                          const ISerializable *local_model = NULL) {
+    this->CheckPoint_(global_model, local_model, false);
+  }
+  /*!
+   * \brief This function can be used to replace CheckPoint for global_model only,
+   *   when certain condition is met(see detailed expplaination).
+   *
+   * This is a "lazy" checkpoint such that only the pointer to global_model is
+   * remembered and no memory copy is taken. To use this function, the user MUST ensure that:
+   * The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
+   * In another words, global_model model can be changed only between last call of
+   * Allreduce/Broadcast and LazyCheckPoint in current version
+   *
+   * For example, suppose the calling sequence is:
+   * LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
+   *
+   * If user can only changes global_model in code3, then LazyCheckPoint can be used to
+   * improve efficiency of the program.
+   * \param global_model pointer to the globally shared model/state
+   *   when calling this function, the caller need to gauranttees that global_model
+   *   is the same in all nodes
+   * \sa LoadCheckPoint, CheckPoint, VersionNumber
+   */
+  virtual void LazyCheckPoint(const ISerializable *global_model) {
+    this->CheckPoint_(global_model, NULL, true);
+  }
  /*!
   * \brief explicitly re-init everything before calling LoadCheckPoint
   * call this function when IEngine throw an exception out,
@@ -274,10 +299,38 @@ class AllreduceRobust : public AllreduceBase {
    std::vector<uint64_t> data_;
  };
  /*!
-   * \brief reset the all the existing links by sending Out-of-Band message marker
-   * after this function finishes, all the messages received and sent before in all live links are discarded,
-   * This allows us to get a fresh start after error has happened
+   * \brief internal consistency check function,
+   *  use check to ensure user always call CheckPoint/LoadCheckPoint
+   *  with or without local but not both, this function will set the approperiate settings
+   *  in the first call of LoadCheckPoint/CheckPoint
   *
+   * \param with_local whether the user calls CheckPoint with local model
+   */
+  void LocalModelCheck(bool with_local);
+  /*!
+   * \brief internal implementation of checkpoint, support both lazy and normal way
+   *
+   * \param global_model pointer to the globally shared model/state
+   *   when calling this function, the caller need to gauranttees that global_model
+   *   is the same in all nodes
+   * \param local_model pointer to local model, that is specific to current node/rank
+   *   this can be NULL when no local state is needed
+   * \param lazy_checkpt whether the action is lazy checkpoint
+   *
+   * \sa CheckPoint, LazyCheckPoint
+   */
+  void CheckPoint_(const ISerializable *global_model,
+                   const ISerializable *local_model,
+                   bool lazy_checkpt);
+  /*!
+   * \brief reset the all the existing links by sending Out-of-Band message marker
+   * after this function finishes, all the messages received and sent
+   * before in all live links are discarded,
+   * This allows us to get a fresh start after error has happened
+   *
+   * TODO(tqchen): this function is not yet functioning was not used by engine,
+   * simple resetlink and reconnect strategy is used
+   *
   * \return this function can return kSuccess or kSockError
   * when kSockError is returned, it simply means there are bad sockets in the links,
   * and some link recovery proceduer is needed
@@ -468,12 +521,18 @@ o * the input state must exactly one saved state(local state of current node)
  ResultBuffer resbuf;
  // last check point global model
  std::string global_checkpoint;
+  // lazy checkpoint of global model
+  const ISerializable *global_lazycheck;
  // number of replica for local state/model
  int num_local_replica;
  // number of default local replica
  int default_local_replica;
+  // flag to decide whether local model is used, -1: unknown, 0: no, 1:yes
+  int use_local_model;
  // number of replica for global state/model
  int num_global_replica;
+  // number of times recovery happens
+  int recover_counter;
  // --- recovery data structure for local checkpoint
  // there is two version of the data structure,
  // at one time one version is valid and another is used as temp memory
@@ -56,7 +56,8 @@ void Allreduce_(void *sendrecvbuf,
 }

 // code for reduce handle
-ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
+ReduceHandle::ReduceHandle(void)
+    : handle_(NULL), redfunc_(NULL), htype_(NULL) {
 }
 ReduceHandle::~ReduceHandle(void) {}

@@ -64,17 +65,16 @@ int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
  return static_cast<int>(dtype.type_size);
 }
 void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
-  utils::Assert(handle_ == NULL, "cannot initialize reduce handle twice");
-  handle_ = reinterpret_cast<void*>(redfunc);
+  utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice");
+  redfunc_ = redfunc;
 }
 void ReduceHandle::Allreduce(void *sendrecvbuf,
                              size_t type_nbytes, size_t count,
                              IEngine::PreprocFunction prepare_fun,
                              void *prepare_arg) {
-  utils::Assert(handle_ != NULL, "must intialize handle to call AllReduce");
+  utils::Assert(redfunc_ != NULL, "must intialize handle to call AllReduce");
  GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
-                         reinterpret_cast<IEngine::ReduceFunction*>(handle_),
-                         prepare_fun, prepare_arg);
+                         redfunc_, prepare_fun, prepare_arg);
 }
 }  // namespace engine
 }  // namespace rabit
@@ -42,6 +42,9 @@ class EmptyEngine : public IEngine {
                          const ISerializable *local_model = NULL) {
    version_number += 1;
  }
+  virtual void LazyCheckPoint(const ISerializable *global_model) {
+    version_number += 1;
+  }
  virtual int VersionNumber(void) const {
    return version_number;
  }
@@ -45,6 +45,9 @@ class MPIEngine : public IEngine {
                          const ISerializable *local_model = NULL) {
    version_number += 1;
  }
+  virtual void LazyCheckPoint(const ISerializable *global_model) {
+    version_number += 1;
+  }
  virtual int VersionNumber(void) const {
    return version_number;
  }
@@ -134,7 +137,8 @@ void Allreduce_(void *sendrecvbuf,
 }

 // code for reduce handle
-ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
+ReduceHandle::ReduceHandle(void)
+    : handle_(NULL), redfunc_(NULL), htype_(NULL) {
 }
 ReduceHandle::~ReduceHandle(void) {
  if (handle_ != NULL) {
test/test.mk: 14 lines changed

@@ -1,13 +1,7 @@
-ifndef $(nslave)
-nslave=2
-endif
-ifndef $(ndata)
-ndata=10
-endif

 # this is a makefile used to show testcases of rabit
-.PHONY: model_recover local_recover speed
-
+.PHONY:all
+all:

 # this experiment test recovery with actually process exit, use keepalive to keep program alive
 model_recover_10_10k:
@@ -18,3 +12,7 @@ model_recover_10_10k_die_same:

 model_recover_10_10k_die_hard:
 	../tracker/rabit_demo.py -n 10 test_model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0
+
+
+local_recover_10_10k:
+	../tracker/rabit_demo.py -n 10 test_local_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=1,1,1,1
@@ -20,21 +20,41 @@ parser.add_argument('command', nargs='+',
                     help = 'command for rabit program')
 args = parser.parse_args()

+# bash script for keepalive
+# use it so that python do not need to communicate with subprocess
+echo="echo %s rabit_num_trial=$nrep;"
+keepalive = """
+nrep=0
+rc=254
+while [ $rc -eq 254 ];
+do
+    %s
+    %s %s rabit_num_trial=$nrep
+    rc=$?;
+    nrep=$((nrep+1));
+done
+"""
+
 def exec_cmd(cmd, taskid):
     if cmd[0].find('/') == -1 and os.path.exists(cmd[0]):
         cmd[0] = './' + cmd[0]
     cmd = ' '.join(cmd)
+    arg = ' rabit_task_id=%d' % (taskid)
+    cmd = cmd + arg
     ntrial = 0
     while True:
         prep = 'PYTHONPATH=\"%s\" ' % WRAPPER_PATH
-        arg = ' rabit_task_id=%d rabit_num_trial=%d' % (taskid, ntrial)
-        ret = subprocess.call(prep + cmd + arg, shell = True)
-        if ret == 254 or ret == -2:
-            ntrial += 1
-            continue
+        if args.verbose != 0:
+            bash = keepalive % (echo % cmd, prep, cmd)
+        else:
+            bash = keepalive % ('', prep, cmd)
+        ret = subprocess.call(bash, shell=True, executable='bash')
         if ret == 0:
+            if args.verbose != 0:
+                print 'Thread %d exit with 0' % taskid
             return
-        raise Exception('Get nonzero return code=%d' % ret)
+        else:
+            raise Exception('Get nonzero return code=%d' % ret)
 #
 # Note: this submit script is only used for demo purpose
 # submission script using pyhton multi-threading
@@ -51,6 +71,7 @@ def mthread_submit(nslave, worker_args):
    procs = {}
    for i in range(nslave):
        procs[i] = Thread(target = exec_cmd, args = (args.command + worker_args, i))
+        procs[i].daemon = True
        procs[i].start()
    for i in range(nslave):
        procs[i].join()
@@ -257,6 +257,7 @@ class Tracker:
 def submit(nslave, args, fun_submit, verbose):
    master = Tracker(verbose = verbose)
    submit_thread = Thread(target = fun_submit, args = (nslave, args + master.slave_args()))
+    submit_thread.daemon = True
    submit_thread.start()
    master.accept_slaves(nslave)
    submit_thread.join()