checkin allreduce recover

This commit is contained in:
tqchen 2014-11-30 22:41:04 -08:00
parent 9355f5faf2
commit 16f729115e
2 changed files with 74 additions and 23 deletions

View File

@ -13,6 +13,9 @@
#include "./engine_robust.h"
namespace engine {
/*!
 * \brief default constructor
 *  starts with result_buffer_round set to 1; every value is 0 modulo 1,
 *  so the modulo comparison in AllReduce never asks for a stored result to be dropped
 */
AllReduceRobust::AllReduceRobust()
    : result_buffer_round(1) {
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
@ -25,16 +28,27 @@ void AllReduceRobust::AllReduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
// NOTE(review): this listing is a rendered diff, so statements the commit removes
// (the direct TryAllReduce on sendrecvbuf_ with its early return, the kSockError
// report, the link reset and the TODO) appear interleaved with the statements it
// adds; the braces do not balance as shown.
// ask RecoverExec about call number seq_counter before doing any work;
// presumably it returns true when sendrecvbuf_ was filled by recovery -- confirm
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
// now we are free to remove the last result, if any:
// it is dropped when its sequence number and this node's rank differ modulo result_buffer_round
if (resbuf.LastSeqNo() != -1 &&
(resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
resbuf.DropLast();
}
// temporary space obtained from the result buffer; it is handed to PushTemp after the loop
void *temp = resbuf.AllocTemp(type_nbytes, count);
while (true) {
ReturnType ret = TryAllReduce(sendrecvbuf_, type_nbytes, count, reducer);
if (ret == kSuccess) return;
if (ret == kSockError) {
utils::Error("error occur during all reduce\n");
if (recovered) {
// sendrecvbuf_ is taken as final here: mirror it into temp and leave the loop
std::memcpy(temp, sendrecvbuf_, type_nbytes * count); break;
} else {
// run the reduction on a copy in temp; sendrecvbuf_ is only overwritten below
// once CheckAndRecover accepts the attempt
std::memcpy(temp, sendrecvbuf_, type_nbytes * count);
if (CheckAndRecover(TryAllReduce(temp, type_nbytes, count, reducer))) {
// attempt accepted: copy the reduced values back into the caller's buffer
std::memcpy(sendrecvbuf_, temp, type_nbytes * count); break;
} else {
// attempt rejected: run RecoverExec again and re-test `recovered` on the next iteration
recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
}
utils::LogPrintf("[%d] receive except signal, start reset link\n", rank);
TryResetLinks();
}
// TODO
}
// register the content of temp under the current sequence number, then advance the counter
resbuf.PushTemp(seq_counter, type_nbytes, count);
seq_counter += 1;
}
/*!
* \brief broadcast data from root to all nodes
@ -329,7 +343,6 @@ AllReduceRobust::TryDecideRouting(AllReduceRobust::RecoverType role,
*p_recvlink = best_link;
return kSuccess;
}
/*!
* \brief try to finish the data recovery request,
* this function is used together with TryDecideRouting
@ -448,8 +461,17 @@ AllReduceRobust::TryRecoverData(RecoverType role,
* \sa ReturnType
*/
AllReduceRobust::ReturnType AllReduceRobust::TryLoadCheckPoint(bool requester) {
// NOTE(review): rendered diff -- the next statement is the old stub body that the
// commit removes; as listed it would make everything below it unreachable
return kSuccess;
// a requester asks for the model, every other node offers its own copy
RecoverType role = requester ? kRequestData : kHaveData;
// length of the locally stored model; passed by pointer to TryDecideRouting,
// presumably so it can be replaced by the agreed size -- confirm
size_t size = this->checked_model.length();
int recv_link;
std::vector<bool> req_in;
ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);
if (succ != kSuccess) return succ;
if (role == kRequestData) {
// make room for the incoming model before it is written through &checked_model[0]
checked_model.resize(size);
}
// checked before &checked_model[0] is taken below
utils::Check(size != 0, "zero size check point is not allowed");
return TryRecoverData(role, &checked_model[0], size, recv_link, req_in);
}
/*!
* \brief try to get the result of operation specified by seqno
@ -458,17 +480,27 @@ AllReduceRobust::ReturnType AllReduceRobust::TryLoadCheckPoint(bool requester) {
* only the nodes with requester set to true really needs to get the result
* other nodes acts as collaborative roles to complete this request
*
* \param buf the buffer to store the result, this parameter is only use when current node is requester
* \param size the total size of the buffer, this parameter is only use when current node is requester
* \param sendrecvbuf the buffer to store the result, this parameter is only used when current node is requester
* \param size the total size of the buffer, this parameter is only used when current node is requester
* \param seqno sequence number of the operation, this is the unique index of an operation in current iteration
* \param requester whether current node is the requester
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllReduceRobust::ReturnType
AllReduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) {
// NOTE(review): rendered diff -- the two statements below are the old stub body that
// the commit removes, and the signature appears a second time right after them;
// the listing is not valid C++ as shown
utils::Error("TryGetResult: not implemented");
return kSuccess;
AllReduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) { RecoverType role;
if (!requester) {
// non-requesters ignore the incoming buffer and look the result up locally;
// on a hit Query also overwrites size with the stored size
sendrecvbuf = resbuf.Query(seqno, &size);
// a node holding the result serves it, a node without it only passes data along
role = sendrecvbuf != NULL ? kHaveData : kPassData;
} else {
role = kRequestData;
}
int recv_link;
std::vector<bool> req_in;
ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);
if (succ != kSuccess) return succ;
// NOTE(review): the message mentions "check point" although this path recovers an
// operation result -- looks copied from TryLoadCheckPoint; consider rewording
utils::Check(size != 0, "zero size check point is not allowed");
return TryRecoverData(role, sendrecvbuf, size, recv_link, req_in);
}
/*!
* \brief try to run recover execution for a request action described by flag and seqno,

View File

@ -17,6 +17,7 @@ namespace engine {
/*! \brief implementation of fault tolerant all reduce engine */
class AllReduceRobust : public AllReduceBase {
public:
AllReduceRobust(void);
virtual ~AllReduceRobust(void) {}
/*!
* \brief perform in-place allreduce, on sendrecvbuf
@ -179,6 +180,19 @@ class AllReduceRobust : public AllReduceBase {
*p_size = size_[idx];
return BeginPtr(data_) + rptr_[idx];
}
/*!
 * \brief discard the most recently stored result
 *  asserts when nothing is stored; after the last entry is removed from
 *  seqno_, rptr_ and size_, data_ is resized to the offset now last in rptr_
 */
inline void DropLast() {
  utils::Assert(!seqno_.empty(), "there is nothing to be dropped");
  size_.pop_back();
  rptr_.pop_back();
  seqno_.pop_back();
  data_.resize(rptr_.back());
}
/*!
 * \brief sequence number of the most recently stored result
 * \return the last entry of seqno_, or -1 when nothing is stored
 */
inline int LastSeqNo() const {
  return seqno_.empty() ? -1 : seqno_.back();
}
private:
// sequence number of each
std::vector<int> seqno_;
@ -248,8 +262,8 @@ class AllReduceRobust : public AllReduceBase {
* only the nodes with requester set to true really needs to get the result
* other nodes acts as collaborative roles to complete this request
*
* \param buf the buffer to store the result, this parameter is only use when current node is requester
* \param size the total size of the buffer, this parameter is only use when current node is requester
* \param buf the buffer to store the result, this parameter is only used when current node is requester
* \param size the total size of the buffer, this parameter is only used when current node is requester
* \param seqno sequence number of the operation, this is unique index of a operation in current iteration
* \param requester whether current node is the requester
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
@ -325,8 +339,13 @@ class AllReduceRobust : public AllReduceBase {
// call sequence counter, records how many calls we made so far
// from last call to CheckPoint, LoadCheckPoint
int seq_counter;
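// the round of the result buffer: result sequence numbers and ranks are compared
// modulo this value when deciding whether a stored result can be dropped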
int result_buffer_round;
// result buffer
ResultBuffer resbuf;
// last check point model
std::string checked_model;
};
} // namespace engine
// implementation of inline template function