diff --git a/src/engine_robust.cc b/src/engine_robust.cc index 9f03bea5e..6b820f98a 100644 --- a/src/engine_robust.cc +++ b/src/engine_robust.cc @@ -13,6 +13,9 @@ #include "./engine_robust.h" namespace engine { +AllReduceRobust::AllReduceRobust(void) { + result_buffer_round = 1; +} /*! * \brief perform in-place allreduce, on sendrecvbuf * this function is NOT thread-safe @@ -23,18 +26,29 @@ namespace engine { */ void AllReduceRobust::AllReduce(void *sendrecvbuf_, size_t type_nbytes, - size_t count, + size_t count, ReduceFunction reducer) { - while (true) { - ReturnType ret = TryAllReduce(sendrecvbuf_, type_nbytes, count, reducer); - if (ret == kSuccess) return; - if (ret == kSockError) { - utils::Error("error occur during all reduce\n"); - } - utils::LogPrintf("[%d] receive except signal, start reset link\n", rank); - TryResetLinks(); + bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter); + // now we are free to remove the last result, if any + if (resbuf.LastSeqNo() != -1 && + (resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) { + resbuf.DropLast(); } - // TODO + void *temp = resbuf.AllocTemp(type_nbytes, count); + while (true) { + if (recovered) { + std::memcpy(temp, sendrecvbuf_, type_nbytes * count); break; + } else { + std::memcpy(temp, sendrecvbuf_, type_nbytes * count); + if (CheckAndRecover(TryAllReduce(temp, type_nbytes, count, reducer))) { + std::memcpy(sendrecvbuf_, temp, type_nbytes * count); break; + } else { + recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter); + } + } + } + resbuf.PushTemp(seq_counter, type_nbytes, count); + seq_counter += 1; } /*! * \brief broadcast data from root to all nodes @@ -329,7 +343,6 @@ AllReduceRobust::TryDecideRouting(AllReduceRobust::RecoverType role, *p_recvlink = best_link; return kSuccess; } - /*! * \brief try to finish the data recovery request, * this function is used together with TryDecideRouting @@ -417,7 +430,7 @@ AllReduceRobust::TryRecoverData(RecoverType role, if (req_in[i]) min_write = std::min(links[i].size_write, min_write); } utils::Assert(min_write <= links[pid].size_read, "boundary check"); - if (!links[pid].ReadToRingBuffer(min_write)) return kSockError; + if (!links[pid].ReadToRingBuffer(min_write)) return kSockError; } for (int i = 0; i < nlink; ++i) { if (req_in[i] && selecter.CheckWrite(links[i].sock)) { @@ -438,7 +451,7 @@ AllReduceRobust::TryRecoverData(RecoverType role, } /*! * \brief try to load check point - * + * * This is a collaborative function called by all nodes * only the nodes with requester set to true really needs to load the check point * other nodes acts as collaborative roles to complete this request @@ -448,8 +461,17 @@ AllReduceRobust::TryRecoverData(RecoverType role, * \sa ReturnType */ AllReduceRobust::ReturnType AllReduceRobust::TryLoadCheckPoint(bool requester) { - - return kSuccess; + RecoverType role = requester ? kRequestData : kHaveData; + size_t size = this->checked_model.length(); + int recv_link; + std::vector req_in; + ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in); + if (succ != kSuccess) return succ; + if (role == kRequestData) { + checked_model.resize(size); + } + utils::Check(size != 0, "zero size check point is not allowed"); + return TryRecoverData(role, &checked_model[0], size, recv_link, req_in); } /*! * \brief try to get the result of operation specified by seqno @@ -458,17 +480,27 @@ AllReduceRobust::ReturnType AllReduceRobust::TryLoadCheckPoint(bool requester) { * only the nodes with requester set to true really needs to get the result * other nodes acts as collaborative roles to complete this request * - * \param buf the buffer to store the result, this parameter is only use when current node is requester - * \param size the total size of the buffer, this parameter is only use when current node is requester + * \param buf the buffer to store the result, this parameter is only used when current node is requester + * \param size the total size of the buffer, this parameter is only used when current node is requester * \param seqno sequence number of the operation, this is unique index of a operation in current iteration * \param requester whether current node is the requester * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details * \sa ReturnType */ AllReduceRobust::ReturnType -AllReduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) { - utils::Error("TryGetResult: not implemented"); - return kSuccess; +AllReduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) { RecoverType role; + if (!requester) { + sendrecvbuf = resbuf.Query(seqno, &size); + role = sendrecvbuf != NULL ? kHaveData : kPassData; + } else { + role = kRequestData; + } + int recv_link; + std::vector req_in; + ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in); + if (succ != kSuccess) return succ; + utils::Check(size != 0, "zero size check point is not allowed"); + return TryRecoverData(role, sendrecvbuf, size, recv_link, req_in); } /*! * \brief try to run recover execution for a request action described by flag and seqno, diff --git a/src/engine_robust.h b/src/engine_robust.h index be9cf0998..92febdd70 100644 --- a/src/engine_robust.h +++ b/src/engine_robust.h @@ -16,7 +16,8 @@ namespace engine { /*! \brief implementation of fault tolerant all reduce engine */ class AllReduceRobust : public AllReduceBase { - public: + public: + AllReduceRobust(void); virtual ~AllReduceRobust(void) {} /*! * \brief perform in-place allreduce, on sendrecvbuf @@ -178,6 +179,19 @@ class AllReduceRobust : public AllReduceBase { if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL; *p_size = size_[idx]; return BeginPtr(data_) + rptr_[idx]; + } + // drop last stored result + inline void DropLast(void) { + utils::Assert(seqno_.size() != 0, "there is nothing to be dropped"); + seqno_.pop_back(); + rptr_.pop_back(); + size_.pop_back(); + data_.resize(rptr_.back()); + } + // the sequence number of last stored result + inline int LastSeqNo(void) const { + if (seqno_.size() == 0) return -1; + return seqno_.back(); } private: // sequence number of each @@ -248,8 +262,8 @@ class AllReduceRobust : public AllReduceBase { * only the nodes with requester set to true really needs to get the result * other nodes acts as collaborative roles to complete this request * - * \param buf the buffer to store the result, this parameter is only use when current node is requester - * \param size the total size of the buffer, this parameter is only use when current node is requester + * \param buf the buffer to store the result, this parameter is only used when current node is requester + * \param size the total size of the buffer, this parameter is only used when current node is requester * \param seqno sequence number of the operation, this is unique index of a operation in current iteration * \param requester whether current node is the requester * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details @@ -325,8 +339,13 @@ class AllReduceRobust : public AllReduceBase { // call sequence counter, records how many calls we made so far // from last call to CheckPoint, LoadCheckPoint int seq_counter; + // the round of result buffer, used to mode the result + int result_buffer_round; // result buffer ResultBuffer resbuf; + // last check point model + std::string checked_model; + }; } // namespace engine // implementation of inline template function