add ring setup version

This commit is contained in:
tqchen 2014-12-07 16:09:28 -08:00
parent 322e40c72e
commit e2adce1cc1
9 changed files with 334 additions and 101 deletions

View File

@ -7,6 +7,7 @@
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include <map>
#include <cstdlib>
#include <cstring>
#include "./allreduce_base.h"
@ -43,17 +44,18 @@ void AllreduceBase::Init(void) {
}
// start socket
utils::Socket::Startup();
utils::Assert(links.size() == 0, "can only call Init once");
utils::Assert(all_links.size() == 0, "can only call Init once");
this->host_uri = utils::SockAddr::GetHostName();
// get information from tracker
this->ReConnectLinks();
}
void AllreduceBase::Shutdown(void) {
for (size_t i = 0; i < links.size(); ++i) {
links[i].sock.Close();
for (size_t i = 0; i < all_links.size(); ++i) {
all_links[i].sock.Close();
}
links.clear();
all_links.clear();
tree_links.plinks.clear();
if (tracker_uri == "NULL") return;
int magic = kMagic;
@ -121,8 +123,12 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
tracker.SendStr(task_id);
tracker.SendStr(std::string(cmd));
// the rank of previous link, next link in ring
int prev_rank, next_rank;
// the rank of neighbors
std::map<int, int> tree_neighbors;
{// get new ranks
int newrank;
int newrank, num_neighbors;
utils::Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
"ReConnectLink failure 4");
utils::Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == sizeof(parent_rank),
@ -130,8 +136,20 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
utils::Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 4");
utils::Assert(rank == -1 || newrank == rank, "must keep rank to same if the node already have one");
rank = newrank;
}
rank = newrank;
utils::Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == sizeof(num_neighbors),
"ReConnectLink failure 4");
for (int i = 0; i < num_neighbors; ++i) {
int nrank;
utils::Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
"ReConnectLink failure 4");
tree_neighbors[nrank] = 1;
}
utils::Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
"ReConnectLink failure 4");
utils::Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
"ReConnectLink failure 4");
}
// create listening socket
utils::TCPSocket sock_listen;
sock_listen.Create();
@ -144,11 +162,11 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
do {
// send over good links
std::vector<int> good_link;
for (size_t i = 0; i < links.size(); ++i) {
if (!links[i].sock.BadSocket()) {
good_link.push_back(static_cast<int>(links[i].rank));
for (size_t i = 0; i < all_links.size(); ++i) {
if (!all_links[i].sock.BadSocket()) {
good_link.push_back(static_cast<int>(all_links[i].rank));
} else {
if (!links[i].sock.IsClosed()) links[i].sock.Close();
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
}
}
int ngood = static_cast<int>(good_link.size());
@ -178,13 +196,13 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 13");
utils::Check(hrank == r.rank, "ReConnectLink failure, link rank inconsistent");
bool match = false;
for (size_t i = 0; i < links.size(); ++i) {
if (links[i].rank == hrank) {
utils::Assert(links[i].sock.IsClosed(), "Override a link that is active");
links[i].sock = r.sock; match = true; break;
for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == hrank) {
utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active");
all_links[i].sock = r.sock; match = true; break;
}
}
if (!match) links.push_back(r);
if (!match) all_links.push_back(r);
}
utils::Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), "ReConnectLink failure 14");
} while (num_error != 0);
@ -199,27 +217,35 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 15");
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 15");
bool match = false;
for (size_t i = 0; i < links.size(); ++i) {
if (links[i].rank == r.rank) {
utils::Assert(links[i].sock.IsClosed(), "Override a link that is active");
links[i].sock = r.sock; match = true; break;
for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == r.rank) {
utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active");
all_links[i].sock = r.sock; match = true; break;
}
}
if (!match) links.push_back(r);
if (!match) all_links.push_back(r);
}
// close listening sockets
sock_listen.Close();
this->parent_index = -1;
// setup selecter
for (size_t i = 0; i < links.size(); ++i) {
utils::Assert(!links[i].sock.BadSocket(), "ReConnectLink: bad socket");
// setup tree links and ring structure
tree_links.plinks.clear();
for (size_t i = 0; i < all_links.size(); ++i) {
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode
links[i].sock.SetNonBlock(true);
if (links[i].rank == parent_rank) parent_index = static_cast<int>(i);
}
if (parent_rank != -1) {
utils::Assert(parent_index != -1, "cannot find parent in the link");
all_links[i].sock.SetNonBlock(true);
if (tree_neighbors.count(all_links[i].rank) != 0) {
if (all_links[i].rank == parent_rank) {
parent_index = static_cast<int>(tree_links.plinks.size());
}
tree_links.plinks.push_back(&all_links[i]);
}
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
}
utils::Assert(parent_rank == -1 || parent_index != -1, "cannot find parent in the link");
utils::Assert(prev_rank == -1 || ring_prev != NULL, "cannot find prev ring in the link");
utils::Assert(next_rank == -1 || ring_next != NULL, "cannot find next ring in the link");
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
@ -241,6 +267,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
RefLinkVector &links = tree_links;
if (links.size() == 0 || count == 0) return kSuccess;
// total size of message
const size_t total_size = type_nbytes * count;
@ -391,8 +418,9 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
*/
AllreduceBase::ReturnType
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
RefLinkVector &links = tree_links;
if (links.size() == 0 || total_size == 0) return kSuccess;
utils::Check(root < world_size, "Broadcast: root should be smaller than world size");
utils::Check(root < world_size, "Broadcast: root should be smaller than world size");
// number of links
const int nlink = static_cast<int>(links.size());
// size of space already read from data

View File

@ -259,6 +259,19 @@ class AllreduceBase : public IEngine {
// aligned with 64 bits, will be able to perform 64 bits operations freely
std::vector<uint64_t> buffer_;
};
/*!
 * \brief lightweight vector-like view over link records;
 *  it stores pointers to records that live elsewhere instead of the records themselves
 */
struct RefLinkVector {
  /*! \brief pointers to the referenced link records, not owned by this structure */
  std::vector<LinkRecord*> plinks;
  /*! \return number of links currently referenced */
  size_t size() const {
    return plinks.size();
  }
  /*!
   * \brief get the i-th referenced link, no bounds check is performed
   * \param i position in plinks
   * \return reference to the link record pointed to by plinks[i]
   */
  LinkRecord &operator[](size_t i) {
    LinkRecord *link = plinks[i];
    return *link;
  }
};
/*!
* \brief connect to the tracker to fix the missing links
* this function is also used when the engine start up
@ -306,9 +319,11 @@ class AllreduceBase : public IEngine {
int parent_index;
// rank of parent node, can be -1
int parent_rank;
// sockets of all links
std::vector<LinkRecord> links;
// pointer to someplace in the ring
// sockets of all links this connects to
std::vector<LinkRecord> all_links;
// all the links in the reduction tree connection
RefLinkVector tree_links;
// pointer to links in the ring
LinkRecord *ring_prev, *ring_next;
//----- meta information-----
// unique identifier of the possible job this process is doing

View File

@ -37,6 +37,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
const std::vector<EdgeType> &edge_in,
size_t out_index)
) {
RefLinkVector &links = tree_links;
if (links.size() == 0) return kSuccess;
// number of links
const int nlink = static_cast<int>(links.size());

View File

@ -11,13 +11,14 @@
#include <utility>
#include "./io.h"
#include "./utils.h"
#include "./rabit.h"
#include "./allreduce_robust.h"
namespace rabit {
namespace engine {
/*! \brief constructor, puts the robust engine into its default state */
AllreduceRobust::AllreduceRobust(void) {
  result_buffer_round = 1;
  // 0 replicas: CheckPoint/LoadCheckPoint reject a non-NULL local_model in this mode
  num_local_replica = 0;
  // local_chkpt/local_rptr have two slots indexed by this version;
  // it must be a valid index before the first checkpoint or load
  local_chkpt_version = 0;
  seq_counter = 0;
}
/*! \brief shutdown the engine */
@ -131,9 +132,17 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
*/
int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
utils::ISerializable *local_model) {
utils::Check(local_model == NULL, "CheckPoint local_model is not yet supported");
// check if we succesfll
if (num_local_replica == 0) {
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
}
// check if we are successful
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp)) {
if (local_model != NULL) {
// load in local model
utils::MemoryFixSizeBuffer fs(BeginPtr(local_chkpt[local_chkpt_version]),
local_rptr[local_chkpt_version][1]);
local_model->Load(fs);
}
// reset result buffer
resbuf.Clear(); seq_counter = 0;
// load from buffer
@ -170,7 +179,31 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
*/
void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
const utils::ISerializable *local_model) {
utils::Assert(local_model == NULL, "CheckPoint local model is not supported yet");
if (num_local_replica == 0) {
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
}
if (num_local_replica != 0) {
while (true) {
if (RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckPoint)) break;
// save model to the new version slot
int new_version = !local_chkpt_version;
local_chkpt[new_version].clear();
utils::MemoryBufferStream fs(&local_chkpt[new_version]);
if (local_model != NULL) {
local_model->Save(fs);
}
local_rptr[new_version].clear();
local_rptr[new_version].push_back(0);
local_rptr[new_version].push_back(local_chkpt[new_version].length());
if (CheckAndRecover(TryCheckinLocalState(&local_rptr[new_version],
&local_chkpt[new_version]))) break;
}
// run the ack phase
utils::Assert(RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckAck),
"check point must return true");
// switch pointer to new version
local_chkpt_version = !local_chkpt_version;
}
// execute checkpoint, note: when checkpoint existing, load will not happen
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
"check point must return true");
@ -199,32 +232,32 @@ void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
*/
AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
// number of links
const int nlink = static_cast<int>(links.size());
const int nlink = static_cast<int>(all_links.size());
for (int i = 0; i < nlink; ++i) {
links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size);
links[i].ResetSize();
all_links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size);
all_links[i].ResetSize();
}
// read and discard data from all channels until pass mark
while (true) {
for (int i = 0; i < nlink; ++i) {
if (links[i].sock.BadSocket()) continue;
if (links[i].size_write == 0) {
if (all_links[i].sock.BadSocket()) continue;
if (all_links[i].size_write == 0) {
char sig = kOOBReset;
ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB);
ssize_t len = all_links[i].sock.Send(&sig, sizeof(sig), MSG_OOB);
// error will be filtered in next loop
if (len == sizeof(sig)) links[i].size_write = 1;
if (len == sizeof(sig)) all_links[i].size_write = 1;
}
if (links[i].size_write == 1) {
if (all_links[i].size_write == 1) {
char sig = kResetMark;
ssize_t len = links[i].sock.Send(&sig, sizeof(sig));
if (len == sizeof(sig)) links[i].size_write = 2;
ssize_t len = all_links[i].sock.Send(&sig, sizeof(sig));
if (len == sizeof(sig)) all_links[i].size_write = 2;
}
}
utils::SelectHelper rsel;
bool finished = true;
for (int i = 0; i < nlink; ++i) {
if (links[i].size_write != 2 && !links[i].sock.BadSocket()) {
rsel.WatchWrite(links[i].sock); finished = false;
if (all_links[i].size_write != 2 && !all_links[i].sock.BadSocket()) {
rsel.WatchWrite(all_links[i].sock); finished = false;
}
}
if (finished) break;
@ -232,32 +265,32 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
rsel.Select();
}
for (int i = 0; i < nlink; ++i) {
if (!links[i].sock.BadSocket()) {
utils::SelectHelper::WaitExcept(links[i].sock);
if (!all_links[i].sock.BadSocket()) {
utils::SelectHelper::WaitExcept(all_links[i].sock);
}
}
while (true) {
for (int i = 0; i < nlink; ++i) {
if (links[i].size_read == 0) {
int atmark = links[i].sock.AtMark();
if (all_links[i].size_read == 0) {
int atmark = all_links[i].sock.AtMark();
if (atmark < 0) {
utils::Assert(links[i].sock.BadSocket(), "must already gone bad");
utils::Assert(all_links[i].sock.BadSocket(), "must already gone bad");
} else if (atmark > 0) {
links[i].size_read = 1;
all_links[i].size_read = 1;
} else {
// no at mark, read and discard data
ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size);
if (links[i].sock.AtMark()) links[i].size_read = 1;
ssize_t len = all_links[i].sock.Recv(all_links[i].buffer_head, all_links[i].buffer_size);
if (all_links[i].sock.AtMark()) all_links[i].size_read = 1;
// zero length, remote closed the connection, close socket
if (len == 0) links[i].sock.Close();
if (len == 0) all_links[i].sock.Close();
}
}
}
utils::SelectHelper rsel;
bool finished = true;
for (int i = 0; i < nlink; ++i) {
if (links[i].size_read == 0 && !links[i].sock.BadSocket()) {
rsel.WatchRead(links[i].sock); finished = false;
if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) {
rsel.WatchRead(all_links[i].sock); finished = false;
}
}
if (finished) break;
@ -266,22 +299,22 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
// start synchronization, use blocking I/O to avoid select
for (int i = 0; i < nlink; ++i) {
if (!links[i].sock.BadSocket()) {
if (!all_links[i].sock.BadSocket()) {
char oob_mark;
links[i].sock.SetNonBlock(false);
ssize_t len = links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL);
all_links[i].sock.SetNonBlock(false);
ssize_t len = all_links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL);
if (len == 0) {
links[i].sock.Close(); continue;
all_links[i].sock.Close(); continue;
} else if (len > 0) {
utils::Assert(oob_mark == kResetMark, "wrong oob msg");
utils::Assert(links[i].sock.AtMark() != 1, "should already read past mark");
utils::Assert(all_links[i].sock.AtMark() != 1, "should already read past mark");
} else {
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
}
// send out ack
char ack = kResetAck;
while (true) {
len = links[i].sock.Send(&ack, sizeof(ack));
len = all_links[i].sock.Send(&ack, sizeof(ack));
if (len == sizeof(ack)) break;
if (len == -1) {
if (errno != EAGAIN && errno != EWOULDBLOCK) break;
@ -291,22 +324,22 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
}
// wait all ack
for (int i = 0; i < nlink; ++i) {
if (!links[i].sock.BadSocket()) {
if (!all_links[i].sock.BadSocket()) {
char ack;
ssize_t len = links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL);
ssize_t len = all_links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL);
if (len == 0) {
links[i].sock.Close(); continue;
all_links[i].sock.Close(); continue;
} else if (len > 0) {
utils::Assert(ack == kResetAck, "wrong Ack MSG");
} else {
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
}
// set back to nonblock mode
links[i].sock.SetNonBlock(true);
all_links[i].sock.SetNonBlock(true);
}
}
for (int i = 0; i < nlink; ++i) {
if (links[i].sock.BadSocket()) return kSockError;
if (all_links[i].sock.BadSocket()) return kSockError;
}
return kSuccess;
}
@ -320,8 +353,8 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
if (err_type == kSuccess) return true;
// simple way, shutdown all links
for (size_t i = 0; i < links.size(); ++i) {
if (!links[i].sock.BadSocket()) links[i].sock.Close();
for (size_t i = 0; i < all_links.size(); ++i) {
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
}
ReConnectLinks("recover");
return false;
@ -479,6 +512,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
size_t size,
int recv_link,
const std::vector<bool> &req_in) {
RefLinkVector &links = tree_links;
// no need to run recovery for zero size message
if (links.size() == 0 || size == 0) return kSuccess;
utils::Assert(req_in.size() == links.size(), "TryRecoverData");
@ -580,17 +614,48 @@ AllreduceRobust::TryRecoverData(RecoverType role,
* \sa ReturnType
*/
AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
RecoverType role = requester ? kRequestData : kHaveData;
// check in local data
RecoverType role = requester ? kRequestData : kHaveData;
ReturnType succ;
if (num_local_replica != 0) {
if (requester) {
// clear existing history, if any, before load
local_rptr[local_chkpt_version].clear();
local_chkpt[local_chkpt_version].clear();
}
// recover local checkpoint
succ = TryRecoverLocalState(&local_rptr[local_chkpt_version],
&local_chkpt[local_chkpt_version]);
if (succ != kSuccess) return succ;
int nlocal = std::max(static_cast<int>(local_rptr[local_chkpt_version].size()) - 1, 0);
// check if everyone is OK
unsigned state = 0;
if (nlocal == num_local_replica + 1) {
// complete recovery
state = 1;
} else if (nlocal == 0) {
// get nothing
state = 2;
} else {
// partially complete state
state = 4;
}
succ = TryAllreduce(&state, sizeof(state), 1, op::Reducer<op::BitOR, unsigned>);
if (succ != kSuccess) return succ;
utils::Check(state == 1 || state == 2,
"LoadCheckPoint: too many nodes fails, cannot recover local state");
}
// recover global checkpoint
size_t size = this->global_checkpoint.length();
int recv_link;
std::vector<bool> req_in;
ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);
succ = TryDecideRouting(role, &size, &recv_link, &req_in);
if (succ != kSuccess) return succ;
if (role == kRequestData) {
global_checkpoint.resize(size);
}
if (size == 0) return kSuccess;
return TryRecoverData(role, &global_checkpoint[0], size, recv_link, req_in);
return TryRecoverData(role, BeginPtr(global_checkpoint), size, recv_link, req_in);
}
/*!
* \brief try to get the result of operation specified by seqno
@ -607,11 +672,21 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
* \sa ReturnType
*/
AllreduceRobust::ReturnType
AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) { RecoverType role;
AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) {
// if minimum sequence requested is local check point ack,
// this means all nodes have finished local check point, directly return
if (seqno == ActionSummary::kLocalCheckAck) return kSuccess;
if (seqno == ActionSummary::kLocalCheckPoint) {
// new version of local model
int new_version = !local_chkpt_version;
int nlocal = std::max(static_cast<int>(local_rptr[new_version].size()) - 1, 0);
// if we reach this place, the state must have already been set up once
utils::Assert(nlocal == 1 || nlocal == num_local_replica + 1,
"TryGetResult::Checkpoint");
return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]);
}
// handles normal data recovery
RecoverType role;
if (!requester) {
sendrecvbuf = resbuf.Query(seqno, &size);
role = sendrecvbuf != NULL ? kHaveData : kPassData;
@ -786,7 +861,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
}
chkpt.resize(rptr.back());
// pass data through the link
succ = RingPassing(&chkpt[0], rptr[nlocal], rptr[nread_end],
succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end],
rptr[nwrite_start], rptr[nread_end],
ring_next, ring_prev);
if (succ != kSuccess) {
@ -849,7 +924,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
}
chkpt.resize(rptr.back());
// pass data through the link
succ = RingPassing(&chkpt[0], rptr[nlocal], rptr[nread_end],
succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end],
rptr[nwrite_start], rptr[nwrite_end],
ring_prev, ring_next);
if (succ != kSuccess) {
@ -858,6 +933,57 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
}
return kSuccess;
}
/*!
 * \brief try to checkpoint local state, this function is called in the normal execution phase
 *  of a checkpoint that contains local state
 *  the input must contain exactly one saved state (the local state of the current node),
 *  after completion, this function will get the local states of the previous num_local_replica
 *  nodes and put them into local_chkpt and local_rptr
 *
 *  It is also OK to call TryRecoverLocalState instead,
 *  TryRecoverLocalState makes fewer assumptions about the input, and requires more communication
 *
 * \param p_local_rptr the pointer to the segment pointers in the states array
 * \param p_local_chkpt the pointer to the storage of local check points
 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
 * \sa ReturnType, TryRecoverLocalState
 */
AllreduceRobust::ReturnType
AllreduceRobust::TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
                                      std::string *p_local_chkpt) {
  // if there is no local replica, we can do nothing
  if (num_local_replica == 0) return kSuccess;
  std::vector<size_t> &rptr = *p_local_rptr;
  std::string &chkpt = *p_local_chkpt;
  utils::Assert(rptr.size() == 2, "TryCheckinLocalState must have exactly 1 state");
  const int n = num_local_replica;
  // sizes[0] is the size of our own state, sizes[1..n] are filled in by the ring pass
  std::vector<size_t> sizes(n + 1);
  sizes[0] = rptr[1] - rptr[0];
  ReturnType succ;
  // pass size through the link
  succ = RingPassing(BeginPtr(sizes),
                     1 * sizeof(size_t),
                     (n + 1) * sizeof(size_t),
                     0 * sizeof(size_t),
                     n * sizeof(size_t),
                     ring_prev, ring_next);
  if (succ != kSuccess) return succ;
  // update rptr: n + 1 states (own state plus n replicas) need n + 2 segment pointers,
  // every boundary up to rptr[n + 1] must be filled before it is used below
  rptr.resize(n + 2);
  for (int i = 1; i <= n; ++i) {
    rptr[i + 1] = rptr[i] + sizes[i];
  }
  chkpt.resize(rptr.back());
  // pass data through the link
  succ = RingPassing(BeginPtr(chkpt),
                     rptr[1], rptr[n + 1],
                     rptr[0], rptr[n],
                     ring_prev, ring_next);
  if (succ != kSuccess) {
    // roll back to the single local state that was passed in
    rptr.resize(2); chkpt.resize(rptr.back()); return succ;
  }
  return kSuccess;
}
/*!
* \brief perform a ring passing to receive data from prev link, and sent data to next link
* this allows data to stream over a ring structure
@ -883,7 +1009,7 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
size_t write_end,
LinkRecord *read_link,
LinkRecord *write_link) {
if (links.size() == 0 || read_end == 0) return kSuccess;
if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess;
utils::Assert(read_end <= write_end, "boundary check");
utils::Assert(read_ptr <= read_end, "boundary check");
utils::Assert(write_ptr <= write_end, "boundary check");

View File

@ -372,6 +372,23 @@ class AllreduceRobust : public AllreduceBase {
*/
ReturnType TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
std::string *p_local_chkpt);
/*!
 * \brief try to checkpoint local state, this function is called in the normal execution phase
 *  of a checkpoint that contains local state
 *  the input must contain exactly one saved state (the local state of the current node),
 *  after completion, this function will get the local states of the previous num_local_replica
 *  nodes and put them into local_chkpt and local_rptr
 *
 *  It is also OK to call TryRecoverLocalState instead,
 *  TryRecoverLocalState makes fewer assumptions about the input, and requires more communication
 *
 * \param p_local_rptr the pointer to the segment pointers in the states array
 * \param p_local_chkpt the pointer to the storage of local check points
 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
 * \sa ReturnType, TryRecoverLocalState
 */
ReturnType TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
std::string *p_local_chkpt);
/*!
* \brief perform a ring passing to receive data from prev link, and sent data to next link
* this allows data to stream over a ring structure
@ -441,7 +458,7 @@ class AllreduceRobust : public AllreduceBase {
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring
std::vector<size_t> local_rptr[2];
// storage for local model replicas
std::string local_checkpoint[2];
std::string local_chkpt[2];
// version of local checkpoint can be 1 or 0
int local_chkpt_version;
};

View File

@ -63,25 +63,32 @@ class SlaveEntry:
return job_map[self.jobid]
return -1
def get_neighbor(self, rank, nslave):
rank = rank + 1
ret = []
if rank > 1:
ret.append(rank / 2 - 1)
if rank * 2 - 1 < nslave:
ret.append(rank * 2 - 1)
if rank * 2 < nslave:
ret.append(rank * 2)
return set(ret)
def assign_rank(self, rank, wait_conn, nslave):
def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map):
self.rank = rank
nnset = self.get_neighbor(rank, nslave)
nnset = set(tree_map[rank])
rprev, rnext = ring_map[rank]
self.sock.sendint(rank)
# send parent rank
self.sock.sendint((rank + 1) / 2 - 1)
self.sock.sendint(parent_map[rank])
# send world size
self.sock.sendint(nslave)
self.sock.sendint(len(tree_map))
self.sock.sendint(len(nnset))
# send the rprev and next link
for r in nnset:
self.sock.sendint(r)
# send prev link
if rprev != -1 and rprev != rank:
nnset.add(rprev)
self.sock.sendint(rprev)
else:
self.sock.sendint(-1)
# send next link
if rnext != -1 and rnext != rank:
nnset.add(rnext)
self.sock.sendint(rnext)
else:
self.sock.sendint(-1)
while True:
ngood = self.sock.recvint()
goodset = set([])
@ -131,8 +138,35 @@ class Tracker:
self.sock.close()
def slave_args(self):
return ['rabit_tracker_uri=%s' % socket.gethostname(),
'rabit_tracker_port=%s' % self.port]
'rabit_tracker_port=%s' % self.port]
def get_neighbor(self, rank, nslave):
    """Return the tree neighbors of ``rank`` among ``nslave`` nodes.

    Nodes are laid out as a binary heap: with 1-based index k = rank + 1,
    the parent is k // 2 and the children are 2k and 2k + 1.  The returned
    list holds 0-based ranks: the parent first (absent for the root),
    followed by whichever children fall inside ``nslave``.
    """
    rank = rank + 1
    ret = []
    if rank > 1:
        # floor division keeps the parent rank an integer even when true
        # division is in effect; the result is the same as `/` on Python 2 ints
        ret.append(rank // 2 - 1)
    if rank * 2 - 1 < nslave:
        ret.append(rank * 2 - 1)
    if rank * 2 < nslave:
        ret.append(rank * 2)
    return ret
def get_tree(self, nslave):
    """Build the tree topology for ``nslave`` nodes.

    Returns a pair ``(tree_map, parent_map)``: ``tree_map[r]`` is the
    neighbor list from ``self.get_neighbor(r, nslave)`` and
    ``parent_map[r]`` is the rank of r's parent, which is -1 for rank 0.
    """
    tree_map = {}
    parent_map = {}
    for r in range(nslave):
        tree_map[r] = self.get_neighbor(r, nslave)
        # floor division keeps the parent rank an integer even when true
        # division is in effect; the result is the same as `/` on Python 2 ints
        parent_map[r] = (r + 1) // 2 - 1
    return tree_map, parent_map
def get_ring(self, tree_map, parent_map):
    """Build the ring topology over the ranks 0 .. len(tree_map) - 1.

    Returns a dict mapping each rank to a ``(prev, next)`` pair, where the
    ring simply follows rank order and wraps around at the ends.
    ``parent_map`` is accepted but not used here.
    """
    count = len(tree_map)
    return dict(
        (node, ((node + count - 1) % count, (node + 1) % count))
        for node in range(count)
    )
def accept_slaves(self, nslave):
tree_map, parent_map = self.get_tree(nslave)
ring_map = self.get_ring(tree_map, parent_map)
# set of nodes that finishs the job
shutdown = {}
# set of nodes that is waiting for connections
@ -163,7 +197,7 @@ class Tracker:
rank = todo_nodes.pop(0)
if s.jobid != 'NULL':
job_map[s.jobid] = rank
s.assign_rank(rank, wait_conn, nslave)
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
if s.wait_accept > 0:
wait_conn[rank] = s
print 'All nodes finishes job'

View File

@ -153,7 +153,8 @@ class Socket {
* \param end_port ending port number to try
* \return the port successfully bind to, return -1 if failed to bind any port
*/
inline int TryBindHost(int start_port, int end_port) {
inline int TryBindHost(int start_port, int end_port) {
// TODO, add prefix check
for (int port = start_port; port < end_port; ++port) {
SockAddr addr("0.0.0.0", port);
if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0) {

View File

@ -187,5 +187,13 @@ inline const T *BeginPtr(const std::vector<T> &vec) {
return &vec[0];
}
}
/*!
 * \brief get the beginning address of a string
 * \param str the string to take the address from
 * \return pointer to the first character, or NULL when the string is empty
 */
inline char* BeginPtr(std::string &str) {
  return str.empty() ? static_cast<char*>(NULL) : &str[0];
}
/*!
 * \brief get the beginning address of a constant string
 * \param str the string to take the address from
 * \return pointer to the first character, or NULL when the string is empty
 */
inline const char* BeginPtr(const std::string &str) {
  return str.empty() ? static_cast<const char*>(NULL) : &str[0];
}
} // namespace rabit
#endif // RABIT_UTILS_H_

View File

@ -24,7 +24,10 @@ def mpi_submit(nslave, args):
args arguments to launch each job
this usually includes the parameters of master_uri and parameters passed into submit
"""
cmd = ' '.join(['mpirun -n %d --hostfile %s' % (nslave, args[0])] + args[1:])
if args[0] == 'local':
cmd = ' '.join(['mpirun -n %d' % (nslave)] + args[1:])
else:
cmd = ' '.join(['mpirun -n %d --hostfile %s' % (nslave, args[0])] + args[1:])
print cmd
subprocess.check_call(cmd, shell = True)