diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 90a32dbee..72fa12e79 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -7,6 +7,7 @@ #define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_DEPRECATE #define NOMINMAX +#include #include #include #include "./allreduce_base.h" @@ -43,17 +44,18 @@ void AllreduceBase::Init(void) { } // start socket utils::Socket::Startup(); - utils::Assert(links.size() == 0, "can only call Init once"); + utils::Assert(all_links.size() == 0, "can only call Init once"); this->host_uri = utils::SockAddr::GetHostName(); // get information from tracker this->ReConnectLinks(); } void AllreduceBase::Shutdown(void) { - for (size_t i = 0; i < links.size(); ++i) { - links[i].sock.Close(); + for (size_t i = 0; i < all_links.size(); ++i) { + all_links[i].sock.Close(); } - links.clear(); + all_links.clear(); + tree_links.plinks.clear(); if (tracker_uri == "NULL") return; int magic = kMagic; @@ -121,8 +123,12 @@ void AllreduceBase::ReConnectLinks(const char *cmd) { utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3"); tracker.SendStr(task_id); tracker.SendStr(std::string(cmd)); + // the rank of previous link, next link in ring + int prev_rank, next_rank; + // the rank of neighbors + std::map tree_neighbors; {// get new ranks - int newrank; + int newrank, num_neighbors; utils::Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank), "ReConnectLink failure 4"); utils::Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == sizeof(parent_rank), @@ -130,8 +136,20 @@ void AllreduceBase::ReConnectLinks(const char *cmd) { utils::Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 4"); utils::Assert(rank == -1 || newrank == rank, "must keep rank to same if the node already have one"); - rank = newrank; - } + rank = newrank; + utils::Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == sizeof(num_neighbors), + "ReConnectLink failure 4"); + for (int i = 0; i < num_neighbors; ++i) { + int nrank; + utils::Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank), + "ReConnectLink failure 4"); + tree_neighbors[nrank] = 1; + } + utils::Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank), + "ReConnectLink failure 4"); + utils::Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank), + "ReConnectLink failure 4"); + } // create listening socket utils::TCPSocket sock_listen; sock_listen.Create(); @@ -144,11 +162,11 @@ void AllreduceBase::ReConnectLinks(const char *cmd) { do { // send over good links std::vector good_link; - for (size_t i = 0; i < links.size(); ++i) { - if (!links[i].sock.BadSocket()) { - good_link.push_back(static_cast(links[i].rank)); + for (size_t i = 0; i < all_links.size(); ++i) { + if (!all_links[i].sock.BadSocket()) { + good_link.push_back(static_cast(all_links[i].rank)); } else { - if (!links[i].sock.IsClosed()) links[i].sock.Close(); + if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close(); } } int ngood = static_cast(good_link.size()); @@ -178,13 +196,13 @@ void AllreduceBase::ReConnectLinks(const char *cmd) { utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 13"); utils::Check(hrank == r.rank, "ReConnectLink failure, link rank inconsistent"); bool match = false; - for (size_t i = 0; i < links.size(); ++i) { - if (links[i].rank == hrank) { - utils::Assert(links[i].sock.IsClosed(), "Override a link that is active"); - links[i].sock = r.sock; match = true; break; + for (size_t i = 0; i < all_links.size(); ++i) { + if (all_links[i].rank == hrank) { + utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active"); + all_links[i].sock = r.sock; match = true; break; } } - if (!match) links.push_back(r); + if (!match) all_links.push_back(r); } utils::Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), "ReConnectLink failure 14"); } while (num_error != 0); @@ -199,27 +217,35 @@ void AllreduceBase::ReConnectLinks(const char *cmd) { utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 15"); utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 15"); bool match = false; - for (size_t i = 0; i < links.size(); ++i) { - if (links[i].rank == r.rank) { - utils::Assert(links[i].sock.IsClosed(), "Override a link that is active"); - links[i].sock = r.sock; match = true; break; + for (size_t i = 0; i < all_links.size(); ++i) { + if (all_links[i].rank == r.rank) { + utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active"); + all_links[i].sock = r.sock; match = true; break; } } - if (!match) links.push_back(r); + if (!match) all_links.push_back(r); } // close listening sockets sock_listen.Close(); this->parent_index = -1; - // setup selecter - for (size_t i = 0; i < links.size(); ++i) { - utils::Assert(!links[i].sock.BadSocket(), "ReConnectLink: bad socket"); + // setup tree links and ring structure + tree_links.plinks.clear(); + for (size_t i = 0; i < all_links.size(); ++i) { + utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket"); // set the socket to non-blocking mode - links[i].sock.SetNonBlock(true); - if (links[i].rank == parent_rank) parent_index = static_cast(i); - } - if (parent_rank != -1) { - utils::Assert(parent_index != -1, "cannot find parent in the link"); + all_links[i].sock.SetNonBlock(true); + if (tree_neighbors.count(all_links[i].rank) != 0) { + if (all_links[i].rank == parent_rank) { + parent_index = static_cast(tree_links.plinks.size()); + } + tree_links.plinks.push_back(&all_links[i]); + } + if (all_links[i].rank == prev_rank) ring_prev = &all_links[i]; + if (all_links[i].rank == next_rank) ring_next = &all_links[i]; } + utils::Assert(parent_rank == -1 || parent_index != -1, "cannot find parent in the link"); + utils::Assert(prev_rank == -1 || ring_prev != NULL, "cannot find prev ring in the link"); + utils::Assert(next_rank == -1 || ring_next != NULL, "cannot find next ring in the link"); } /*! * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure @@ -241,6 +267,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, ReduceFunction reducer) { + RefLinkVector &links = tree_links; if (links.size() == 0 || count == 0) return kSuccess; // total size of message const size_t total_size = type_nbytes * count; @@ -391,8 +418,9 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_, */ AllreduceBase::ReturnType AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { + RefLinkVector &links = tree_links; if (links.size() == 0 || total_size == 0) return kSuccess; - utils::Check(root < world_size, "Broadcast: root should be smaller than world size"); + utils::Check(root < world_size, "Broadcast: root should be smaller than world size"); // number of links const int nlink = static_cast(links.size()); // size of space already read from data diff --git a/src/allreduce_base.h b/src/allreduce_base.h index 6eea948ce..4ef4a044e 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -259,6 +259,19 @@ class AllreduceBase : public IEngine { // aligned with 64 bits, will be able to perform 64 bits operations freely std::vector buffer_; }; + /*! + * \brief simple data structure that works like a vector + * but takes reference instead of space + */ + struct RefLinkVector { + std::vector plinks; + inline LinkRecord &operator[](size_t i) { + return *plinks[i]; + } + inline size_t size(void) const { + return plinks.size(); + } + }; /*! * \brief connect to the tracker to fix the the missing links * this function is also used when the engine start up @@ -306,9 +319,11 @@ class AllreduceBase : public IEngine { int parent_index; // rank of parent node, can be -1 int parent_rank; - // sockets of all links - std::vector links; - // pointer to someplace in the ring + // sockets of all links this connects to + std::vector all_links; + // all the links in the reduction tree connection + RefLinkVector tree_links; + // pointer to links in the ring LinkRecord *ring_prev, *ring_next; //----- meta information----- // unique identifier of the possible job this process is doing diff --git a/src/allreduce_robust-inl.h b/src/allreduce_robust-inl.h index f1f557593..49f8f2c37 100644 --- a/src/allreduce_robust-inl.h +++ b/src/allreduce_robust-inl.h @@ -37,6 +37,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value, const std::vector &edge_in, size_t out_index) ) { + RefLinkVector &links = tree_links; if (links.size() == 0) return kSuccess; // number of links const int nlink = static_cast(links.size()); diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index dbb318c33..c1d5119cc 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -11,13 +11,14 @@ #include #include "./io.h" #include "./utils.h" +#include "./rabit.h" #include "./allreduce_robust.h" namespace rabit { namespace engine { AllreduceRobust::AllreduceRobust(void) { result_buffer_round = 1; - num_local_replica = 2; + num_local_replica = 0; seq_counter = 0; } /*! \brief shutdown the engine */ @@ -131,9 +132,17 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) */ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model, utils::ISerializable *local_model) { - utils::Check(local_model == NULL, "CheckPoint local_model is not yet supported"); - // check if we succesfll + if (num_local_replica == 0) { + utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model"); + } + // check if we succesful if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp)) { + if (local_model != NULL) { + // load in local model + utils::MemoryFixSizeBuffer fs(BeginPtr(local_chkpt[local_chkpt_version]), + local_rptr[local_chkpt_version][1]); + local_model->Load(fs); + } // reset result buffer resbuf.Clear(); seq_counter = 0; // load from buffer @@ -170,7 +179,31 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model, */ void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model, const utils::ISerializable *local_model) { - utils::Assert(local_model == NULL, "CheckPoint local model is not supported yet"); + if (num_local_replica == 0) { + utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model"); + } + if (num_local_replica != 0) { + while (true) { + if (RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckPoint)) break; + // save model model to new version place + int new_version = !local_chkpt_version; + local_chkpt[new_version].clear(); + utils::MemoryBufferStream fs(&local_chkpt[new_version]); + if (local_model != NULL) { + local_model->Save(fs); + } + local_rptr[new_version].clear(); + local_rptr[new_version].push_back(0); + local_rptr[new_version].push_back(local_chkpt[new_version].length()); + if (CheckAndRecover(TryCheckinLocalState(&local_rptr[new_version], + &local_chkpt[new_version]))) break; + } + // run the ack phase + utils::Assert(RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckAck), + "check point must return true"); + // switch pointer to new version + local_chkpt_version = !local_chkpt_version; + } // execute checkpoint, note: when checkpoint existing, load will not happen utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp), "check point must return true"); @@ -199,32 +232,32 @@ void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model, */ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { // number of links - const int nlink = static_cast(links.size()); + const int nlink = static_cast(all_links.size()); for (int i = 0; i < nlink; ++i) { - links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size); - links[i].ResetSize(); + all_links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size); + all_links[i].ResetSize(); } // read and discard data from all channels until pass mark while (true) { for (int i = 0; i < nlink; ++i) { - if (links[i].sock.BadSocket()) continue; - if (links[i].size_write == 0) { + if (all_links[i].sock.BadSocket()) continue; + if (all_links[i].size_write == 0) { char sig = kOOBReset; - ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB); + ssize_t len = all_links[i].sock.Send(&sig, sizeof(sig), MSG_OOB); // error will be filtered in next loop - if (len == sizeof(sig)) links[i].size_write = 1; + if (len == sizeof(sig)) all_links[i].size_write = 1; } - if (links[i].size_write == 1) { + if (all_links[i].size_write == 1) { char sig = kResetMark; - ssize_t len = links[i].sock.Send(&sig, sizeof(sig)); - if (len == sizeof(sig)) links[i].size_write = 2; + ssize_t len = all_links[i].sock.Send(&sig, sizeof(sig)); + if (len == sizeof(sig)) all_links[i].size_write = 2; } } utils::SelectHelper rsel; bool finished = true; for (int i = 0; i < nlink; ++i) { - if (links[i].size_write != 2 && !links[i].sock.BadSocket()) { - rsel.WatchWrite(links[i].sock); finished = false; + if (all_links[i].size_write != 2 && !all_links[i].sock.BadSocket()) { + rsel.WatchWrite(all_links[i].sock); finished = false; } } if (finished) break; @@ -232,32 +265,32 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { rsel.Select(); } for (int i = 0; i < nlink; ++i) { - if (!links[i].sock.BadSocket()) { - utils::SelectHelper::WaitExcept(links[i].sock); + if (!all_links[i].sock.BadSocket()) { + utils::SelectHelper::WaitExcept(all_links[i].sock); } } while (true) { for (int i = 0; i < nlink; ++i) { - if (links[i].size_read == 0) { - int atmark = links[i].sock.AtMark(); + if (all_links[i].size_read == 0) { + int atmark = all_links[i].sock.AtMark(); if (atmark < 0) { - utils::Assert(links[i].sock.BadSocket(), "must already gone bad"); + utils::Assert(all_links[i].sock.BadSocket(), "must already gone bad"); } else if (atmark > 0) { - links[i].size_read = 1; + all_links[i].size_read = 1; } else { // no at mark, read and discard data - ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size); - if (links[i].sock.AtMark()) links[i].size_read = 1; + ssize_t len = all_links[i].sock.Recv(all_links[i].buffer_head, all_links[i].buffer_size); + if (all_links[i].sock.AtMark()) all_links[i].size_read = 1; // zero length, remote closed the connection, close socket - if (len == 0) links[i].sock.Close(); + if (len == 0) all_links[i].sock.Close(); } } } utils::SelectHelper rsel; bool finished = true; for (int i = 0; i < nlink; ++i) { - if (links[i].size_read == 0 && !links[i].sock.BadSocket()) { - rsel.WatchRead(links[i].sock); finished = false; + if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) { + rsel.WatchRead(all_links[i].sock); finished = false; } } if (finished) break; @@ -266,22 +299,22 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { // start synchronization, use blocking I/O to avoid select for (int i = 0; i < nlink; ++i) { - if (!links[i].sock.BadSocket()) { + if (!all_links[i].sock.BadSocket()) { char oob_mark; - links[i].sock.SetNonBlock(false); - ssize_t len = links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL); + all_links[i].sock.SetNonBlock(false); + ssize_t len = all_links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL); if (len == 0) { - links[i].sock.Close(); continue; + all_links[i].sock.Close(); continue; } else if (len > 0) { utils::Assert(oob_mark == kResetMark, "wrong oob msg"); - utils::Assert(links[i].sock.AtMark() != 1, "should already read past mark"); + utils::Assert(all_links[i].sock.AtMark() != 1, "should already read past mark"); } else { utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); } // send out ack char ack = kResetAck; while (true) { - len = links[i].sock.Send(&ack, sizeof(ack)); + len = all_links[i].sock.Send(&ack, sizeof(ack)); if (len == sizeof(ack)) break; if (len == -1) { if (errno != EAGAIN && errno != EWOULDBLOCK) break; @@ -291,22 +324,22 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { } // wait all ack for (int i = 0; i < nlink; ++i) { - if (!links[i].sock.BadSocket()) { + if (!all_links[i].sock.BadSocket()) { char ack; - ssize_t len = links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL); + ssize_t len = all_links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL); if (len == 0) { - links[i].sock.Close(); continue; + all_links[i].sock.Close(); continue; } else if (len > 0) { utils::Assert(ack == kResetAck, "wrong Ack MSG"); } else { utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); } // set back to nonblock mode - links[i].sock.SetNonBlock(true); + all_links[i].sock.SetNonBlock(true); } } for (int i = 0; i < nlink; ++i) { - if (links[i].sock.BadSocket()) return kSockError; + if (all_links[i].sock.BadSocket()) return kSockError; } return kSuccess; } @@ -320,8 +353,8 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { bool AllreduceRobust::CheckAndRecover(ReturnType err_type) { if (err_type == kSuccess) return true; // simple way, shutdown all links - for (size_t i = 0; i < links.size(); ++i) { - if (!links[i].sock.BadSocket()) links[i].sock.Close(); + for (size_t i = 0; i < all_links.size(); ++i) { + if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close(); } ReConnectLinks("recover"); return false; @@ -479,6 +512,7 @@ AllreduceRobust::TryRecoverData(RecoverType role, size_t size, int recv_link, const std::vector &req_in) { + RefLinkVector &links = tree_links; // no need to run recovery for zero size message if (links.size() == 0 || size == 0) return kSuccess; utils::Assert(req_in.size() == links.size(), "TryRecoverData"); @@ -580,17 +614,48 @@ AllreduceRobust::TryRecoverData(RecoverType role, * \sa ReturnType */ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) { - RecoverType role = requester ? kRequestData : kHaveData; + // check in local data + RecoverType role = requester ? kRequestData : kHaveData; + ReturnType succ; + if (num_local_replica != 0) { + if (requester) { + // clear existing history, if any, before load + local_rptr[local_chkpt_version].clear(); + local_chkpt[local_chkpt_version].clear(); + } + // recover local checkpoint + succ = TryRecoverLocalState(&local_rptr[local_chkpt_version], + &local_chkpt[local_chkpt_version]); + if (succ != kSuccess) return succ; + int nlocal = std::max(static_cast(local_rptr[local_chkpt_version].size()) - 1, 0); + // check if everyone is OK + unsigned state = 0; + if (nlocal == num_local_replica + 1) { + // complete recovery + state = 1; + } else if (nlocal == 0) { + // get nothing + state = 2; + } else { + // partially complete state + state = 4; + } + succ = TryAllreduce(&state, sizeof(state), 1, op::Reducer); + if (succ != kSuccess) return succ; + utils::Check(state == 1 || state == 2, + "LoadCheckPoint: too many nodes fails, cannot recover local state"); + } + // recover global checkpoint size_t size = this->global_checkpoint.length(); int recv_link; std::vector req_in; - ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in); + succ = TryDecideRouting(role, &size, &recv_link, &req_in); if (succ != kSuccess) return succ; if (role == kRequestData) { global_checkpoint.resize(size); } if (size == 0) return kSuccess; - return TryRecoverData(role, &global_checkpoint[0], size, recv_link, req_in); + return TryRecoverData(role, BeginPtr(global_checkpoint), size, recv_link, req_in); } /*! * \brief try to get the result of operation specified by seqno @@ -607,11 +672,21 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) { * \sa ReturnType */ AllreduceRobust::ReturnType -AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) { RecoverType role; +AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) { // if minimum sequence requested is local check point ack, // this means all nodes have finished local check point, directly return if (seqno == ActionSummary::kLocalCheckAck) return kSuccess; - + if (seqno == ActionSummary::kLocalCheckPoint) { + // new version of local model + int new_version = !local_chkpt_version; + int nlocal = std::max(static_cast(local_rptr[new_version].size()) - 1, 0); + // if we goes to this place, use must have already setup the state once + utils::Assert(nlocal == 1 || nlocal == num_local_replica + 1, + "TryGetResult::Checkpoint"); + return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]); + } + // handles normal data recovery + RecoverType role; if (!requester) { sendrecvbuf = resbuf.Query(seqno, &size); role = sendrecvbuf != NULL ? kHaveData : kPassData; @@ -786,7 +861,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, } chkpt.resize(rptr.back()); // pass data through the link - succ = RingPassing(&chkpt[0], rptr[nlocal], rptr[nread_end], + succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end], rptr[nwrite_start], rptr[nread_end], ring_next, ring_prev); if (succ != kSuccess) { @@ -849,7 +924,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, } chkpt.resize(rptr.back()); // pass data through the link - succ = RingPassing(&chkpt[0], rptr[nlocal], rptr[nread_end], + succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end], rptr[nwrite_start], rptr[nwrite_end], ring_prev, ring_next); if (succ != kSuccess) { @@ -858,6 +933,57 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, } return kSuccess; } +/*! + * \brief try to checkpoint local state, this function is called in normal executation phase + * of checkpoint that contains local state + * the input state must exactly one saved state(local state of current node), + * after complete, this function will get local state from previous num_local_replica nodes and put them + * into local_chkpt and local_rptr + * + * It is also OK to call TryRecoverLocalState instead, + * TryRecoverLocalState makes less assumption about the input, and requires more communications + * + * \param p_local_rptr the pointer to the segment pointers in the states array + * \param p_local_chkpt the pointer to the storage of local check points + * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details + * \sa ReturnType, TryRecoverLocalState + */ +AllreduceRobust::ReturnType +AllreduceRobust::TryCheckinLocalState(std::vector *p_local_rptr, + std::string *p_local_chkpt) { + // if there is no local replica, we can do nothing + if (num_local_replica == 0) return kSuccess; + std::vector &rptr = *p_local_rptr; + std::string &chkpt = *p_local_chkpt; + utils::Assert(rptr.size() == 2, "TryCheckinLocalState must have exactly 1 state"); + const int n = num_local_replica; + std::vector sizes(n + 1); + sizes[0] = rptr[1] - rptr[0]; + ReturnType succ; + // pass size through the link + succ = RingPassing(BeginPtr(sizes), + 1 * sizeof(size_t), + (n + 1) * sizeof(size_t), + 0 * sizeof(size_t), + n * sizeof(size_t), + ring_prev, ring_next); + if (succ != kSuccess) return succ; + // update rptr + rptr.resize(n + 1); + for (int i = 1; i < n; ++i) { + rptr[i + 1] = rptr[i] + sizes[i]; + } + chkpt.resize(rptr.back()); + // pass data through the link + succ = RingPassing(BeginPtr(chkpt), + rptr[1], rptr[n + 1], + rptr[0], rptr[n], + ring_prev, ring_next); + if (succ != kSuccess) { + rptr.resize(2); chkpt.resize(rptr.back()); return succ; + } + return kSuccess; +} /*! * \brief perform a ring passing to receive data from prev link, and sent data to next link * this allows data to stream over a ring structure @@ -883,7 +1009,7 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_, size_t write_end, LinkRecord *read_link, LinkRecord *write_link) { - if (links.size() == 0 || read_end == 0) return kSuccess; + if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess; utils::Assert(read_end <= write_end, "boundary check"); utils::Assert(read_ptr <= read_end, "boundary check"); utils::Assert(write_ptr <= write_end, "boundary check"); diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h index 570960e52..e43e9ac66 100644 --- a/src/allreduce_robust.h +++ b/src/allreduce_robust.h @@ -372,6 +372,23 @@ class AllreduceRobust : public AllreduceBase { */ ReturnType TryRecoverLocalState(std::vector *p_local_rptr, std::string *p_local_chkpt); + /*! + * \brief try to checkpoint local state, this function is called in normal executation phase + * of checkpoint that contains local state +o * the input state must exactly one saved state(local state of current node), + * after complete, this function will get local state from previous num_local_replica nodes and put them + * into local_chkpt and local_rptr + * + * It is also OK to call TryRecoverLocalState instead, + * TryRecoverLocalState makes less assumption about the input, and requires more communications + * + * \param p_local_rptr the pointer to the segment pointers in the states array + * \param p_local_chkpt the pointer to the storage of local check points + * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details + * \sa ReturnType, TryRecoverLocalState + */ + ReturnType TryCheckinLocalState(std::vector *p_local_rptr, + std::string *p_local_chkpt); /*! * \brief perform a ring passing to receive data from prev link, and sent data to next link * this allows data to stream over a ring structure @@ -441,7 +458,7 @@ class AllreduceRobust : public AllreduceBase { // local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring std::vector local_rptr[2]; // storage for local model replicas - std::string local_checkpoint[2]; + std::string local_chkpt[2]; // version of local checkpoint can be 1 or 0 int local_chkpt_version; }; diff --git a/src/rabit_tracker.py b/src/rabit_tracker.py index ceda6347f..fe01a87da 100644 --- a/src/rabit_tracker.py +++ b/src/rabit_tracker.py @@ -63,25 +63,32 @@ class SlaveEntry: return job_map[self.jobid] return -1 - def get_neighbor(self, rank, nslave): - rank = rank + 1 - ret = [] - if rank > 1: - ret.append(rank / 2 - 1) - if rank * 2 - 1 < nslave: - ret.append(rank * 2 - 1) - if rank * 2 < nslave: - ret.append(rank * 2) - return set(ret) - - def assign_rank(self, rank, wait_conn, nslave): + def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map): self.rank = rank - nnset = self.get_neighbor(rank, nslave) + nnset = set(tree_map[rank]) + rprev, rnext = ring_map[rank] self.sock.sendint(rank) # send parent rank - self.sock.sendint((rank + 1) / 2 - 1) + self.sock.sendint(parent_map[rank]) # send world size - self.sock.sendint(nslave) + self.sock.sendint(len(tree_map)) + self.sock.sendint(len(nnset)) + # send the rprev and next link + for r in nnset: + self.sock.sendint(r) + # send prev link + if rprev != -1 and rprev != rank: + nnset.add(rprev) + self.sock.sendint(rprev) + else: + self.sock.sendint(-1) + # send next link + if rnext != -1 and rnext != rank: + nnset.add(rnext) + self.sock.sendint(rnext) + else: + self.sock.sendint(-1) + while True: ngood = self.sock.recvint() goodset = set([]) @@ -131,8 +138,35 @@ class Tracker: self.sock.close() def slave_args(self): return ['rabit_tracker_uri=%s' % socket.gethostname(), - 'rabit_tracker_port=%s' % self.port] + 'rabit_tracker_port=%s' % self.port] + def get_neighbor(self, rank, nslave): + rank = rank + 1 + ret = [] + if rank > 1: + ret.append(rank / 2 - 1) + if rank * 2 - 1 < nslave: + ret.append(rank * 2 - 1) + if rank * 2 < nslave: + ret.append(rank * 2) + return ret + def get_tree(self, nslave): + tree_map = {} + parent_map = {} + for r in range(nslave): + tree_map[r] = self.get_neighbor(r, nslave) + parent_map[r] = (r + 1) / 2 - 1 + return tree_map, parent_map + def get_ring(self, tree_map, parent_map): + ring_map = {} + nslave = len(tree_map) + for r in range(nslave): + rprev = (r + nslave - 1) % nslave + rnext = (r + 1) % nslave + ring_map[r] = (rprev, rnext) + return ring_map def accept_slaves(self, nslave): + tree_map, parent_map = self.get_tree(nslave) + ring_map = self.get_ring(tree_map, parent_map) # set of nodes that finishs the job shutdown = {} # set of nodes that is waiting for connections @@ -163,7 +197,7 @@ class Tracker: rank = todo_nodes.pop(0) if s.jobid != 'NULL': job_map[s.jobid] = rank - s.assign_rank(rank, wait_conn, nslave) + s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map) if s.wait_accept > 0: wait_conn[rank] = s print 'All nodes finishes job' diff --git a/src/socket.h b/src/socket.h index eba1b89f8..65516690d 100644 --- a/src/socket.h +++ b/src/socket.h @@ -153,7 +153,8 @@ class Socket { * \param end_port ending port number to try * \return the port successfully bind to, return -1 if failed to bind any port */ - inline int TryBindHost(int start_port, int end_port) { + inline int TryBindHost(int start_port, int end_port) { + // TODO, add prefix check for (int port = start_port; port < end_port; ++port) { SockAddr addr("0.0.0.0", port); if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0) { diff --git a/src/utils.h b/src/utils.h index d09667d89..e1b34fe2e 100644 --- a/src/utils.h +++ b/src/utils.h @@ -187,5 +187,13 @@ inline const T *BeginPtr(const std::vector &vec) { return &vec[0]; } } +inline char* BeginPtr(std::string &str) { + if (str.length() == 0) return NULL; + return &str[0]; +} +inline const char* BeginPtr(const std::string &str) { + if (str.length() == 0) return NULL; + return &str[0]; +} } // namespace rabit #endif // RABIT_UTILS_H_ diff --git a/submit_mpi.py b/submit_mpi.py index 3a65ec440..468604317 100755 --- a/submit_mpi.py +++ b/submit_mpi.py @@ -24,7 +24,10 @@ def mpi_submit(nslave, args): args arguments to launch each job this usually includes the parameters of master_uri and parameters passed into submit """ - cmd = ' '.join(['mpirun -n %d --hostfile %s' % (nslave, args[0])] + args[1:]) + if args[0] == 'local': + cmd = ' '.join(['mpirun -n %d' % (nslave)] + args[1:]) + else: + cmd = ' '.join(['mpirun -n %d --hostfile %s' % (nslave, args[0])] + args[1:]) print cmd subprocess.check_call(cmd, shell = True)