diff --git a/src/engine_robust.cc b/src/engine_robust.cc index cd393f445..44497bfbe 100644 --- a/src/engine_robust.cc +++ b/src/engine_robust.cc @@ -9,6 +9,7 @@ #define NOMINMAX #include #include +#include "./io.h" #include "./utils.h" #include "./engine_robust.h" @@ -16,6 +17,7 @@ namespace rabit { namespace engine { AllReduceRobust::AllReduceRobust(void) { result_buffer_round = 1; + seq_counter = 0; } /*! * \brief perform in-place allreduce, on sendrecvbuf @@ -58,9 +60,26 @@ void AllReduceRobust::AllReduce(void *sendrecvbuf_, * \param root the root worker id to broadcast the data */ void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) { - utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess, - "AllReduce failed"); - // TODO + bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter); + // drop the last buffered result, if any, unless LastSeqNo() and rank are congruent modulo result_buffer_round + if (resbuf.LastSeqNo() != -1 && + (resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) { + resbuf.DropLast(); + } + void *temp = resbuf.AllocTemp(1, total_size); + while (true) { + if (recovered) { + std::memcpy(temp, sendrecvbuf_, total_size); break; + } else { + if (CheckAndRecover(TryBroadcast(sendrecvbuf_, total_size, root))) { + std::memcpy(temp, sendrecvbuf_, total_size); break; + } else { + recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter); + } + } + } + resbuf.PushTemp(seq_counter, 1, total_size); + seq_counter += 1; } /*! 
* \brief load latest check point @@ -69,15 +88,43 @@ void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) * false if there was no stored checkpoint, means we are start over gain */ bool AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) { - // TODO - return false; + // check whether the load-check recovery step succeeded + if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) { + // if loaded model is empty, this simply means we did not call checkpoint yet + // ask caller to reinit model + if (checked_model.length() == 0) return false; + // load from buffer + utils::MemoryBufferStream fs(&checked_model); + p_model->Load(fs); + // reset result buffer and sequence counter + resbuf.Clear(); seq_counter = 0; + // run another phase of check ack, if recovered from data + utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq), + "check ack must return true"); + return true; + } else { + // nothing loaded, a fresh start, everyone init model + return false; + } } /*! * \brief checkpoint the model, meaning we finished a stage of execution * \param p_model pointer to the model */ void AllReduceRobust::CheckPoint(const utils::ISerializable &model) { - // TODO + // serialize the model into checked_model + checked_model.resize(0); + utils::MemoryBufferStream fs(&checked_model); + model.Save(fs); + utils::Check(checked_model.length() != 0, "CheckPoint: empty model, model.Save must save something"); + // execute checkpoint, note: when checkpoint existing, load will not happen + utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq), + "check point must return true"); + // reset result buffer and sequence counter + resbuf.Clear(); seq_counter = 0; + // execute check ack step, load happens here + utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq), + "check ack must return true"); } /*! 
* \brief reset the all the existing links by sending Out-of-Band message marker diff --git a/src/engine_robust.h b/src/engine_robust.h index 0dbf31852..703a54469 100644 --- a/src/engine_robust.h +++ b/src/engine_robust.h @@ -154,7 +154,7 @@ class AllReduceRobust : public AllReduceBase { rptr_.clear(); rptr_.push_back(0); data_.clear(); } - // allocate temporal space for + // allocate temporary space inline void *AllocTemp(size_t type_nbytes, size_t count) { size_t size = type_nbytes * count; size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t); diff --git a/src/io.h b/src/io.h index 913acaa9a..ed01545f2 100644 --- a/src/io.h +++ b/src/io.h @@ -91,6 +91,7 @@ class IStream { /*! \brief interface of se*/ class ISerializable { + public: /*! \brief load the model from file */ virtual void Load(IStream &fi) = 0; /*! \brief save the model to the stream*/