From cc410b8c90f035846ba957fd313f91b4a1f1fdba Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 4 Dec 2014 11:09:15 -0800 Subject: [PATCH] add local model in checkpoint interface, a new goal --- src/allreduce_base.cc | 2 +- src/allreduce_base.h | 35 ++++++++-- src/allreduce_robust.cc | 139 +++++++++++++++++++++++++++++++++++----- src/allreduce_robust.h | 76 +++++++++++++++++++--- src/engine.h | 24 +++++-- src/engine_mpi.cc | 6 +- src/rabit-inl.h | 2 +- 7 files changed, 248 insertions(+), 36 deletions(-) diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 4c30b62d2..eba06a504 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -348,7 +348,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_, } if (len != -1) { size_down_in += static_cast(len); - utils::Assert(size_down_in <= size_up_out, "Allreduce: boundary error"); + utils::Assert(size_down_in <= size_up_out, "Allreduce: boundary error"); } else { if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError; } diff --git a/src/allreduce_base.h b/src/allreduce_base.h index cd9a5b0d0..29e05f8e5 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -84,23 +84,48 @@ class AllreduceBase : public IEngine { } /*! 
* \brief load latest check point - * \param p_model pointer to the model + * \param global_model pointer to the globally shared model/state + * when calling this function, the caller needs to guarantee that global_model + * is the same in all nodes + * \param local_model pointer to local model, that is specific to current node/rank + * this can be NULL when no local model is needed + * * \return the version number of check point loaded * if returned version == 0, this means no model has been CheckPointed * the p_model is not touched, user should do necessary initialization by themselves + * + * Common usage example: + * int iter = rabit::LoadCheckPoint(&model); + * if (iter == 0) model.InitParameters(); + * for (i = iter; i < max_iter; ++i) { + * do many things, including allreduce + * rabit::CheckPoint(model); + * } + * * \sa CheckPoint, VersionNumber */ - virtual int LoadCheckPoint(utils::ISerializable *p_model) { + virtual int LoadCheckPoint(utils::ISerializable *global_model, + utils::ISerializable *local_model = NULL) { return 0; } /*! * \brief checkpoint the model, meaning we finished a stage of execution * every time we call check point, there is a version number which will increase by one * - * \param p_model pointer to the model + * \param global_model pointer to the globally shared model/state + * when calling this function, the caller needs to guarantee that global_model + * is the same in all nodes + * \param local_model pointer to local model, that is specific to current node/rank + * this can be NULL when no local state is needed + * + * NOTE: local_model requires explicit replication of the model for fault-tolerance, which will + * bring replication cost in CheckPoint function. global_model does not need explicit replication. 
+ * So only CheckPoint with global_model if possible + * * \sa LoadCheckPoint, VersionNumber */ - virtual void CheckPoint(const utils::ISerializable &model) { + virtual void CheckPoint(const utils::ISerializable *global_model, + const utils::ISerializable *local_model = NULL) { version_number += 1; } /*! @@ -267,6 +292,8 @@ class AllreduceBase : public IEngine { int parent_rank; // sockets of all links std::vector links; + // pointer to someplace in the ring + LinkRecord *ring_prev, *ring_next; //----- meta information----- // unique identifier of the possible job this process is doing // used to assign ranks, optional, default to NULL diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index 6aba63e82..a878f5618 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -17,6 +17,7 @@ namespace rabit { namespace engine { AllreduceRobust::AllreduceRobust(void) { result_buffer_round = 1; + num_local_replica = 2; seq_counter = 0; } /*! \brief shutdown the engine */ @@ -108,22 +109,38 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) } /*! 
* \brief load latest check point - * \param p_model pointer to the model + * \param global_model pointer to the globally shared model/state + * when calling this function, the caller need to gauranttees that global_model + * is the same in all nodes + * \param local_model pointer to local model, that is specific to current node/rank + * this can be NULL when no local model is needed + * * \return the version number of check point loaded * if returned version == 0, this means no model has been CheckPointed * the p_model is not touched, user should do necessary initialization by themselves + * + * Common usage example: + * int iter = rabit::LoadCheckPoint(&model); + * if (iter == 0) model.InitParameters(); + * for (i = iter; i < max_iter; ++i) { + * do many things, include allreduce + * rabit::CheckPoint(model); + * } + * * \sa CheckPoint, VersionNumber */ -int AllreduceRobust::LoadCheckPoint(utils::ISerializable *p_model) { +int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model, + utils::ISerializable *local_model) { + utils::Check(local_model == NULL, "CheckPoint local_model is not yet supported"); // check if we succesfll if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) { // reset result buffer resbuf.Clear(); seq_counter = 0; // load from buffer - utils::MemoryBufferStream fs(&checked_model); + utils::MemoryBufferStream fs(&mglobal_model); fs.Read(&version_number, sizeof(version_number)); if (version_number == 0) return version_number; - p_model->Load(fs); + global_model->Load(fs); // run another phase of check ack, if recovered from data utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq), "check ack must return true"); @@ -139,20 +156,31 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *p_model) { * \brief checkpoint the model, meaning we finished a stage of execution * every time we call check point, there is a version number which will increase by one * - * \param p_model 
pointer to the model + * \param global_model pointer to the globally shared model/state + * when calling this function, the caller need to gauranttees that global_model + * is the same in all nodes + * \param local_model pointer to local model, that is specific to current node/rank + * this can be NULL when no local state is needed + * + * NOTE: local_model requires explicit replication of the model for fault-tolerance, which will + * bring replication cost in CheckPoint function. global_model do not need explicit replication. + * So only CheckPoint with global_model if possible + * * \sa LoadCheckPoint, VersionNumber */ -void AllreduceRobust::CheckPoint(const utils::ISerializable &model) { - // increase version number - version_number += 1; - // save model - checked_model.resize(0); - utils::MemoryBufferStream fs(&checked_model); - fs.Write(&version_number, sizeof(version_number)); - model.Save(fs); +void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model, + const utils::ISerializable *local_model) { + utils::Assert(local_model == NULL, "CheckPoint local model is not supported yet"); // execute checkpoint, note: when checkpoint existing, load will not happen utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq), "check point must return true"); + // increase version number + version_number += 1; + // save model + mglobal_model.resize(0); + utils::MemoryBufferStream fs(&mglobal_model); + fs.Write(&version_number, sizeof(version_number)); + global_model->Save(fs); // reset result buffer resbuf.Clear(); seq_counter = 0; // execute check ack step, load happens here @@ -488,6 +516,10 @@ AllreduceRobust::TryRecoverData(RecoverType role, } if (finished) break; selecter.Select(); + // exception handling + for (int i = 0; i < nlink; ++i) { + if (selecter.CheckExcept(links[i].sock)) return kGetExcept; + } if (role == kRequestData) { const int pid = recv_link; if (selecter.CheckRead(links[pid].sock)) { @@ -548,16 +580,16 @@ 
AllreduceRobust::TryRecoverData(RecoverType role, */ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) { RecoverType role = requester ? kRequestData : kHaveData; - size_t size = this->checked_model.length(); + size_t size = this->mglobal_model.length(); int recv_link; std::vector req_in; ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in); if (succ != kSuccess) return succ; if (role == kRequestData) { - checked_model.resize(size); + mglobal_model.resize(size); } if (size == 0) return kSuccess; - return TryRecoverData(role, &checked_model[0], size, recv_link, req_in); + return TryRecoverData(role, &mglobal_model[0], size, recv_link, req_in); } /*! * \brief try to get the result of operation specified by seqno @@ -674,6 +706,81 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) { utils::Assert(false, "RecoverExec: should not reach here"); return true; } + +/*! + * \brief perform a ring passing to receive data from prev link, and send data to next link + * this allows data to stream over a ring structure + * sendrecvbuf[0:read_ptr] are already provided by current node + * current node will recv sendrecvbuf[read_ptr:read_end] from prev link + * current node will send sendrecvbuf[write_ptr:write_end] to next link + * write_ptr will wait till the data is received before sending the data + * this function requires read_end >= write_end (only bytes below read_ptr are ever sent) + * + * \param sendrecvbuf_ the place to hold the incoming and outgoing data + * \param read_ptr the initial read pointer + * \param read_end the ending position to read + * \param write_ptr the initial write pointer + * \param write_end the ending position to write + * \param prev_link pointer to link to previous position in ring + * \param next_link pointer to link of next position in ring + */ +AllreduceRobust::ReturnType +AllreduceRobust::RingPassing(void *sendrecvbuf_, + size_t read_ptr, + size_t read_end, + size_t write_ptr, + size_t write_end, + LinkRecord *prev_link, + 
LinkRecord *next_link) { + if (links.size() == 0 || read_end == 0) return kSuccess; + utils::Assert(write_end <= read_end, "boundary check"); + utils::Assert(read_ptr <= read_end, "boundary check"); + utils::Assert(write_ptr <= write_end, "boundary check"); + // take reference + LinkRecord &prev = *prev_link, &next = *next_link; + // send recv buffer + char *buf = reinterpret_cast<char*>(sendrecvbuf_); + while (true) { + bool finished = true; + utils::SelectHelper selecter; + if (read_ptr != read_end) { + selecter.WatchRead(prev.sock); + finished = false; + } + if (write_ptr < read_ptr && write_ptr != write_end) { + selecter.WatchWrite(next.sock); + finished = false; + } + selecter.WatchException(prev.sock); + selecter.WatchException(next.sock); + if (finished) break; + selecter.Select(); + if (selecter.CheckExcept(prev.sock)) return kGetExcept; + if (selecter.CheckExcept(next.sock)) return kGetExcept; + if (read_ptr != read_end && selecter.CheckRead(prev.sock)) { + ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr); + if (len == 0) { + prev.sock.Close(); return kSockError; + } + if (len != -1) { + read_ptr += static_cast<size_t>(len); + } else { + if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError; + } + } + if (write_ptr != write_end && write_ptr < read_ptr && + selecter.CheckWrite(next.sock)) { + size_t nsend = std::min(write_end - write_ptr, read_ptr - write_ptr); + ssize_t len = next.sock.Send(buf + write_ptr, nsend); + if (len != -1) { + write_ptr += static_cast<size_t>(len); + } else { + if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError; + } + } + } + return kSuccess; +} } // namespace engine } // namespace rabit diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h index ad660da94..d1018907c 100644 --- a/src/allreduce_robust.h +++ b/src/allreduce_robust.h @@ -49,21 +49,46 @@ class AllreduceRobust : public AllreduceBase { virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root); /*! 
* \brief load latest check point - * \param p_model pointer to the model + * \param global_model pointer to the globally shared model/state + * when calling this function, the caller needs to guarantee that global_model + * is the same in all nodes + * \param local_model pointer to local model, that is specific to current node/rank + * this can be NULL when no local model is needed + * * \return the version number of check point loaded * if returned version == 0, this means no model has been CheckPointed * the p_model is not touched, user should do necessary initialization by themselves + * + * Common usage example: + * int iter = rabit::LoadCheckPoint(&model); + * if (iter == 0) model.InitParameters(); + * for (i = iter; i < max_iter; ++i) { + * do many things, including allreduce + * rabit::CheckPoint(model); + * } + * * \sa CheckPoint, VersionNumber */ - virtual int LoadCheckPoint(utils::ISerializable *p_model); + virtual int LoadCheckPoint(utils::ISerializable *global_model, + utils::ISerializable *local_model = NULL); /*! * \brief checkpoint the model, meaning we finished a stage of execution * every time we call check point, there is a version number which will increase by one * - * \param p_model pointer to the model + * \param global_model pointer to the globally shared model/state + * when calling this function, the caller needs to guarantee that global_model + * is the same in all nodes + * \param local_model pointer to local model, that is specific to current node/rank + * this can be NULL when no local state is needed + * + * NOTE: local_model requires explicit replication of the model for fault-tolerance, which will + * bring replication cost in CheckPoint function. global_model does not need explicit replication. 
+ * So only CheckPoint with global_model if possible + * * \sa LoadCheckPoint, VersionNumber */ - virtual void CheckPoint(const utils::ISerializable &model); + virtual void CheckPoint(const utils::ISerializable *global_model, + const utils::ISerializable *local_model = NULL); /*! * \brief explicitly re-init everything before calling LoadCheckPoint * call this function when IEngine throw an exception out, @@ -259,7 +284,7 @@ class AllreduceRobust : public AllreduceBase { * result by recovering procedure, the action is complete, no further action is needed * - false means this is the lastest action that has not yet been executed, need to execute the action */ - bool RecoverExec(void *buf, size_t size, int flag, int seqno = ActionSummary::kMaxSeq); + bool RecoverExec(void *buf, size_t size, int flag, int seqno = ActionSummary::kMaxSeq); /*! * \brief try to load check point * @@ -325,6 +350,30 @@ class AllreduceRobust : public AllreduceBase { size_t size, int recv_link, const std::vector &req_in); + /*! 
+ * \brief perform a ring passing to receive data from prev link, and send data to next link + * this allows data to stream over a ring structure + * sendrecvbuf[0:read_ptr] are already provided by current node + * current node will recv sendrecvbuf[read_ptr:read_end] from prev link + * current node will send sendrecvbuf[write_ptr:write_end] to next link + * write_ptr will wait till the data is received before sending the data + * this function requires read_end >= write_end + * + * \param sendrecvbuf_ the place to hold the incoming and outgoing data + * \param read_ptr the initial read pointer + * \param read_end the ending position to read + * \param write_ptr the initial write pointer + * \param write_end the ending position to write + * \param prev_link pointer to link to previous position in ring + * \param next_link pointer to link of next position in ring + */ + ReturnType RingPassing(void *senrecvbuf_, + size_t read_ptr, + size_t read_end, + size_t write_ptr, + size_t write_end, + LinkRecord *prev_link, + LinkRecord *next_link); /*! 
* \brief run message passing algorithm on the allreduce tree * the result is edge message stored in p_edge_in and p_edge_out @@ -358,10 +407,21 @@ class AllreduceRobust : public AllreduceBase { int seq_counter; // the round of result buffer, used to mode the result int result_buffer_round; - // result buffer + // result buffer of all reduce ResultBuffer resbuf; - // last check point model - std::string checked_model; + // last check point global model + std::string mglobal_model; + // number of replicas for local state/model + int num_local_replica; + // pointer to memory position in the local model + // local model is stored in CSR format (like a sparse matrix) + // local_model[rptr[0]:rptr[1]] stores the model of current node + // local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring + std::vector local_rptr; + // storage for local model replicas + std::string mlocal_model; + // temporary storage + std::string tmp_local_model; }; } // namespace engine } // namespace rabit diff --git a/src/engine.h b/src/engine.h index 6d95fe5dc..e393e94db 100644 --- a/src/engine.h +++ b/src/engine.h @@ -60,7 +60,12 @@ class IEngine { virtual void InitAfterException(void) = 0; /*! 
* \brief load latest check point - * \param p_model pointer to the model + * \param global_model pointer to the globally shared model/state + * when calling this function, the caller needs to guarantee that global_model + * is the same in all nodes + * \param local_model pointer to local model, that is specific to current node/rank + * this can be NULL when no local model is needed + * * \return the version number of check point loaded * if returned version == 0, this means no model has been CheckPointed * the p_model is not touched, user should do necessary initialization by themselves @@ -75,15 +80,26 @@ class IEngine { * * \sa CheckPoint, VersionNumber */ - virtual int LoadCheckPoint(utils::ISerializable *p_model) = 0; + virtual int LoadCheckPoint(utils::ISerializable *global_model, + utils::ISerializable *local_model = NULL) = 0; /*! * \brief checkpoint the model, meaning we finished a stage of execution * every time we call check point, there is a version number which will increase by one * - * \param p_model pointer to the model + * \param global_model pointer to the globally shared model/state + * when calling this function, the caller needs to guarantee that global_model + * is the same in all nodes + * \param local_model pointer to local model, that is specific to current node/rank + * this can be NULL when no local state is needed + * + * NOTE: local_model requires explicit replication of the model for fault-tolerance, which will + * bring replication cost in CheckPoint function. global_model does not need explicit replication. + * So only CheckPoint with global_model if possible + * * \sa LoadCheckPoint, VersionNumber */ - virtual void CheckPoint(const utils::ISerializable &model) = 0; + virtual void CheckPoint(const utils::ISerializable *global_model, + const utils::ISerializable *local_model = NULL) = 0; /*! 
* \return version number of current stored model, * which means how many calls to CheckPoint we made so far diff --git a/src/engine_mpi.cc b/src/engine_mpi.cc index 03bd0cb73..f32dba854 100644 --- a/src/engine_mpi.cc +++ b/src/engine_mpi.cc @@ -32,10 +32,12 @@ class MPIEngine : public IEngine { virtual void InitAfterException(void) { utils::Error("MPI is not fault tolerant"); } - virtual int LoadCheckPoint(utils::ISerializable *p_model) { + virtual int LoadCheckPoint(utils::ISerializable *global_model, + utils::ISerializable *local_model = NULL) { return 0; } - virtual void CheckPoint(const utils::ISerializable &model) { + virtual void CheckPoint(const utils::ISerializable *global_model, + const utils::ISerializable *local_model = NULL) { version_number += 1; } virtual int VersionNumber(void) const { diff --git a/src/rabit-inl.h b/src/rabit-inl.h index f3fd39b2a..b13ea88fc 100644 --- a/src/rabit-inl.h +++ b/src/rabit-inl.h @@ -129,7 +129,7 @@ inline int LoadCheckPoint(utils::ISerializable *p_model) { } // checkpoint the model, meaning we finished a stage of execution inline void CheckPoint(const utils::ISerializable &model) { - engine::GetEngine()->CheckPoint(model); + engine::GetEngine()->CheckPoint(&model); } // return the version number of currently stored model inline int VersionNumber(void) {