add ring setup version
This commit is contained in:
parent
322e40c72e
commit
e2adce1cc1
@ -7,6 +7,7 @@
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#define NOMINMAX
|
||||
#include <map>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include "./allreduce_base.h"
|
||||
@ -43,17 +44,18 @@ void AllreduceBase::Init(void) {
|
||||
}
|
||||
// start socket
|
||||
utils::Socket::Startup();
|
||||
utils::Assert(links.size() == 0, "can only call Init once");
|
||||
utils::Assert(all_links.size() == 0, "can only call Init once");
|
||||
this->host_uri = utils::SockAddr::GetHostName();
|
||||
// get information from tracker
|
||||
this->ReConnectLinks();
|
||||
}
|
||||
|
||||
void AllreduceBase::Shutdown(void) {
|
||||
for (size_t i = 0; i < links.size(); ++i) {
|
||||
links[i].sock.Close();
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
all_links[i].sock.Close();
|
||||
}
|
||||
links.clear();
|
||||
all_links.clear();
|
||||
tree_links.plinks.clear();
|
||||
|
||||
if (tracker_uri == "NULL") return;
|
||||
int magic = kMagic;
|
||||
@ -121,8 +123,12 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
||||
tracker.SendStr(task_id);
|
||||
tracker.SendStr(std::string(cmd));
|
||||
// the rank of previous link, next link in ring
|
||||
int prev_rank, next_rank;
|
||||
// the rank of neighbors
|
||||
std::map<int, int> tree_neighbors;
|
||||
{// get new ranks
|
||||
int newrank;
|
||||
int newrank, num_neighbors;
|
||||
utils::Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
|
||||
"ReConnectLink failure 4");
|
||||
utils::Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == sizeof(parent_rank),
|
||||
@ -130,8 +136,20 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
utils::Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
||||
"ReConnectLink failure 4");
|
||||
utils::Assert(rank == -1 || newrank == rank, "must keep rank to same if the node already have one");
|
||||
rank = newrank;
|
||||
}
|
||||
rank = newrank;
|
||||
utils::Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == sizeof(num_neighbors),
|
||||
"ReConnectLink failure 4");
|
||||
for (int i = 0; i < num_neighbors; ++i) {
|
||||
int nrank;
|
||||
utils::Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
|
||||
"ReConnectLink failure 4");
|
||||
tree_neighbors[nrank] = 1;
|
||||
}
|
||||
utils::Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
|
||||
"ReConnectLink failure 4");
|
||||
utils::Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
|
||||
"ReConnectLink failure 4");
|
||||
}
|
||||
// create listening socket
|
||||
utils::TCPSocket sock_listen;
|
||||
sock_listen.Create();
|
||||
@ -144,11 +162,11 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
do {
|
||||
// send over good links
|
||||
std::vector<int> good_link;
|
||||
for (size_t i = 0; i < links.size(); ++i) {
|
||||
if (!links[i].sock.BadSocket()) {
|
||||
good_link.push_back(static_cast<int>(links[i].rank));
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (!all_links[i].sock.BadSocket()) {
|
||||
good_link.push_back(static_cast<int>(all_links[i].rank));
|
||||
} else {
|
||||
if (!links[i].sock.IsClosed()) links[i].sock.Close();
|
||||
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
|
||||
}
|
||||
}
|
||||
int ngood = static_cast<int>(good_link.size());
|
||||
@ -178,13 +196,13 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 13");
|
||||
utils::Check(hrank == r.rank, "ReConnectLink failure, link rank inconsistent");
|
||||
bool match = false;
|
||||
for (size_t i = 0; i < links.size(); ++i) {
|
||||
if (links[i].rank == hrank) {
|
||||
utils::Assert(links[i].sock.IsClosed(), "Override a link that is active");
|
||||
links[i].sock = r.sock; match = true; break;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (all_links[i].rank == hrank) {
|
||||
utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active");
|
||||
all_links[i].sock = r.sock; match = true; break;
|
||||
}
|
||||
}
|
||||
if (!match) links.push_back(r);
|
||||
if (!match) all_links.push_back(r);
|
||||
}
|
||||
utils::Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), "ReConnectLink failure 14");
|
||||
} while (num_error != 0);
|
||||
@ -199,27 +217,35 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 15");
|
||||
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 15");
|
||||
bool match = false;
|
||||
for (size_t i = 0; i < links.size(); ++i) {
|
||||
if (links[i].rank == r.rank) {
|
||||
utils::Assert(links[i].sock.IsClosed(), "Override a link that is active");
|
||||
links[i].sock = r.sock; match = true; break;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (all_links[i].rank == r.rank) {
|
||||
utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active");
|
||||
all_links[i].sock = r.sock; match = true; break;
|
||||
}
|
||||
}
|
||||
if (!match) links.push_back(r);
|
||||
if (!match) all_links.push_back(r);
|
||||
}
|
||||
// close listening sockets
|
||||
sock_listen.Close();
|
||||
this->parent_index = -1;
|
||||
// setup selecter
|
||||
for (size_t i = 0; i < links.size(); ++i) {
|
||||
utils::Assert(!links[i].sock.BadSocket(), "ReConnectLink: bad socket");
|
||||
// setup tree links and ring structure
|
||||
tree_links.plinks.clear();
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
|
||||
// set the socket to non-blocking mode
|
||||
links[i].sock.SetNonBlock(true);
|
||||
if (links[i].rank == parent_rank) parent_index = static_cast<int>(i);
|
||||
}
|
||||
if (parent_rank != -1) {
|
||||
utils::Assert(parent_index != -1, "cannot find parent in the link");
|
||||
all_links[i].sock.SetNonBlock(true);
|
||||
if (tree_neighbors.count(all_links[i].rank) != 0) {
|
||||
if (all_links[i].rank == parent_rank) {
|
||||
parent_index = static_cast<int>(tree_links.plinks.size());
|
||||
}
|
||||
tree_links.plinks.push_back(&all_links[i]);
|
||||
}
|
||||
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
|
||||
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
|
||||
}
|
||||
utils::Assert(parent_rank == -1 || parent_index != -1, "cannot find parent in the link");
|
||||
utils::Assert(prev_rank == -1 || ring_prev != NULL, "cannot find prev ring in the link");
|
||||
utils::Assert(next_rank == -1 || ring_next != NULL, "cannot find next ring in the link");
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||
@ -241,6 +267,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer) {
|
||||
RefLinkVector &links = tree_links;
|
||||
if (links.size() == 0 || count == 0) return kSuccess;
|
||||
// total size of message
|
||||
const size_t total_size = type_nbytes * count;
|
||||
@ -391,8 +418,9 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
||||
*/
|
||||
AllreduceBase::ReturnType
|
||||
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
RefLinkVector &links = tree_links;
|
||||
if (links.size() == 0 || total_size == 0) return kSuccess;
|
||||
utils::Check(root < world_size, "Broadcast: root should be smaller than world size");
|
||||
utils::Check(root < world_size, "Broadcast: root should be smaller than world size");
|
||||
// number of links
|
||||
const int nlink = static_cast<int>(links.size());
|
||||
// size of space already read from data
|
||||
|
||||
@ -259,6 +259,19 @@ class AllreduceBase : public IEngine {
|
||||
// aligned with 64 bits, will be able to perform 64 bits operations freely
|
||||
std::vector<uint64_t> buffer_;
|
||||
};
|
||||
/*!
|
||||
* \brief simple data structure that works like a vector
|
||||
* but takes reference instead of space
|
||||
*/
|
||||
struct RefLinkVector {
|
||||
std::vector<LinkRecord*> plinks;
|
||||
inline LinkRecord &operator[](size_t i) {
|
||||
return *plinks[i];
|
||||
}
|
||||
inline size_t size(void) const {
|
||||
return plinks.size();
|
||||
}
|
||||
};
|
||||
/*!
|
||||
* \brief connect to the tracker to fix the missing links
|
||||
* this function is also used when the engine start up
|
||||
@ -306,9 +319,11 @@ class AllreduceBase : public IEngine {
|
||||
int parent_index;
|
||||
// rank of parent node, can be -1
|
||||
int parent_rank;
|
||||
// sockets of all links
|
||||
std::vector<LinkRecord> links;
|
||||
// pointer to someplace in the ring
|
||||
// sockets of all links this connects to
|
||||
std::vector<LinkRecord> all_links;
|
||||
// all the links in the reduction tree connection
|
||||
RefLinkVector tree_links;
|
||||
// pointer to links in the ring
|
||||
LinkRecord *ring_prev, *ring_next;
|
||||
//----- meta information-----
|
||||
// unique identifier of the possible job this process is doing
|
||||
|
||||
@ -37,6 +37,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
|
||||
const std::vector<EdgeType> &edge_in,
|
||||
size_t out_index)
|
||||
) {
|
||||
RefLinkVector &links = tree_links;
|
||||
if (links.size() == 0) return kSuccess;
|
||||
// number of links
|
||||
const int nlink = static_cast<int>(links.size());
|
||||
|
||||
@ -11,13 +11,14 @@
|
||||
#include <utility>
|
||||
#include "./io.h"
|
||||
#include "./utils.h"
|
||||
#include "./rabit.h"
|
||||
#include "./allreduce_robust.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
AllreduceRobust::AllreduceRobust(void) {
|
||||
result_buffer_round = 1;
|
||||
num_local_replica = 2;
|
||||
num_local_replica = 0;
|
||||
seq_counter = 0;
|
||||
}
|
||||
/*! \brief shutdown the engine */
|
||||
@ -131,9 +132,17 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
|
||||
*/
|
||||
int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
|
||||
utils::ISerializable *local_model) {
|
||||
utils::Check(local_model == NULL, "CheckPoint local_model is not yet supported");
|
||||
// check if we succesfll
|
||||
if (num_local_replica == 0) {
|
||||
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
|
||||
}
|
||||
// check if we are successful
|
||||
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp)) {
|
||||
if (local_model != NULL) {
|
||||
// load in local model
|
||||
utils::MemoryFixSizeBuffer fs(BeginPtr(local_chkpt[local_chkpt_version]),
|
||||
local_rptr[local_chkpt_version][1]);
|
||||
local_model->Load(fs);
|
||||
}
|
||||
// reset result buffer
|
||||
resbuf.Clear(); seq_counter = 0;
|
||||
// load from buffer
|
||||
@ -170,7 +179,31 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
|
||||
*/
|
||||
void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
|
||||
const utils::ISerializable *local_model) {
|
||||
utils::Assert(local_model == NULL, "CheckPoint local model is not supported yet");
|
||||
if (num_local_replica == 0) {
|
||||
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
|
||||
}
|
||||
if (num_local_replica != 0) {
|
||||
while (true) {
|
||||
if (RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckPoint)) break;
|
||||
// save model model to new version place
|
||||
int new_version = !local_chkpt_version;
|
||||
local_chkpt[new_version].clear();
|
||||
utils::MemoryBufferStream fs(&local_chkpt[new_version]);
|
||||
if (local_model != NULL) {
|
||||
local_model->Save(fs);
|
||||
}
|
||||
local_rptr[new_version].clear();
|
||||
local_rptr[new_version].push_back(0);
|
||||
local_rptr[new_version].push_back(local_chkpt[new_version].length());
|
||||
if (CheckAndRecover(TryCheckinLocalState(&local_rptr[new_version],
|
||||
&local_chkpt[new_version]))) break;
|
||||
}
|
||||
// run the ack phase
|
||||
utils::Assert(RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckAck),
|
||||
"check point must return true");
|
||||
// switch pointer to new version
|
||||
local_chkpt_version = !local_chkpt_version;
|
||||
}
|
||||
// execute checkpoint, note: when checkpoint existing, load will not happen
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
|
||||
"check point must return true");
|
||||
@ -199,32 +232,32 @@ void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
|
||||
*/
|
||||
AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
||||
// number of links
|
||||
const int nlink = static_cast<int>(links.size());
|
||||
const int nlink = static_cast<int>(all_links.size());
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size);
|
||||
links[i].ResetSize();
|
||||
all_links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size);
|
||||
all_links[i].ResetSize();
|
||||
}
|
||||
// read and discard data from all channels until pass mark
|
||||
while (true) {
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (links[i].sock.BadSocket()) continue;
|
||||
if (links[i].size_write == 0) {
|
||||
if (all_links[i].sock.BadSocket()) continue;
|
||||
if (all_links[i].size_write == 0) {
|
||||
char sig = kOOBReset;
|
||||
ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB);
|
||||
ssize_t len = all_links[i].sock.Send(&sig, sizeof(sig), MSG_OOB);
|
||||
// error will be filtered in next loop
|
||||
if (len == sizeof(sig)) links[i].size_write = 1;
|
||||
if (len == sizeof(sig)) all_links[i].size_write = 1;
|
||||
}
|
||||
if (links[i].size_write == 1) {
|
||||
if (all_links[i].size_write == 1) {
|
||||
char sig = kResetMark;
|
||||
ssize_t len = links[i].sock.Send(&sig, sizeof(sig));
|
||||
if (len == sizeof(sig)) links[i].size_write = 2;
|
||||
ssize_t len = all_links[i].sock.Send(&sig, sizeof(sig));
|
||||
if (len == sizeof(sig)) all_links[i].size_write = 2;
|
||||
}
|
||||
}
|
||||
utils::SelectHelper rsel;
|
||||
bool finished = true;
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (links[i].size_write != 2 && !links[i].sock.BadSocket()) {
|
||||
rsel.WatchWrite(links[i].sock); finished = false;
|
||||
if (all_links[i].size_write != 2 && !all_links[i].sock.BadSocket()) {
|
||||
rsel.WatchWrite(all_links[i].sock); finished = false;
|
||||
}
|
||||
}
|
||||
if (finished) break;
|
||||
@ -232,32 +265,32 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
||||
rsel.Select();
|
||||
}
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (!links[i].sock.BadSocket()) {
|
||||
utils::SelectHelper::WaitExcept(links[i].sock);
|
||||
if (!all_links[i].sock.BadSocket()) {
|
||||
utils::SelectHelper::WaitExcept(all_links[i].sock);
|
||||
}
|
||||
}
|
||||
while (true) {
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (links[i].size_read == 0) {
|
||||
int atmark = links[i].sock.AtMark();
|
||||
if (all_links[i].size_read == 0) {
|
||||
int atmark = all_links[i].sock.AtMark();
|
||||
if (atmark < 0) {
|
||||
utils::Assert(links[i].sock.BadSocket(), "must already gone bad");
|
||||
utils::Assert(all_links[i].sock.BadSocket(), "must already gone bad");
|
||||
} else if (atmark > 0) {
|
||||
links[i].size_read = 1;
|
||||
all_links[i].size_read = 1;
|
||||
} else {
|
||||
// no at mark, read and discard data
|
||||
ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size);
|
||||
if (links[i].sock.AtMark()) links[i].size_read = 1;
|
||||
ssize_t len = all_links[i].sock.Recv(all_links[i].buffer_head, all_links[i].buffer_size);
|
||||
if (all_links[i].sock.AtMark()) all_links[i].size_read = 1;
|
||||
// zero length, remote closed the connection, close socket
|
||||
if (len == 0) links[i].sock.Close();
|
||||
if (len == 0) all_links[i].sock.Close();
|
||||
}
|
||||
}
|
||||
}
|
||||
utils::SelectHelper rsel;
|
||||
bool finished = true;
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (links[i].size_read == 0 && !links[i].sock.BadSocket()) {
|
||||
rsel.WatchRead(links[i].sock); finished = false;
|
||||
if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) {
|
||||
rsel.WatchRead(all_links[i].sock); finished = false;
|
||||
}
|
||||
}
|
||||
if (finished) break;
|
||||
@ -266,22 +299,22 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
||||
|
||||
// start synchronization, use blocking I/O to avoid select
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (!links[i].sock.BadSocket()) {
|
||||
if (!all_links[i].sock.BadSocket()) {
|
||||
char oob_mark;
|
||||
links[i].sock.SetNonBlock(false);
|
||||
ssize_t len = links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL);
|
||||
all_links[i].sock.SetNonBlock(false);
|
||||
ssize_t len = all_links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL);
|
||||
if (len == 0) {
|
||||
links[i].sock.Close(); continue;
|
||||
all_links[i].sock.Close(); continue;
|
||||
} else if (len > 0) {
|
||||
utils::Assert(oob_mark == kResetMark, "wrong oob msg");
|
||||
utils::Assert(links[i].sock.AtMark() != 1, "should already read past mark");
|
||||
utils::Assert(all_links[i].sock.AtMark() != 1, "should already read past mark");
|
||||
} else {
|
||||
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
|
||||
}
|
||||
// send out ack
|
||||
char ack = kResetAck;
|
||||
while (true) {
|
||||
len = links[i].sock.Send(&ack, sizeof(ack));
|
||||
len = all_links[i].sock.Send(&ack, sizeof(ack));
|
||||
if (len == sizeof(ack)) break;
|
||||
if (len == -1) {
|
||||
if (errno != EAGAIN && errno != EWOULDBLOCK) break;
|
||||
@ -291,22 +324,22 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
||||
}
|
||||
// wait all ack
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (!links[i].sock.BadSocket()) {
|
||||
if (!all_links[i].sock.BadSocket()) {
|
||||
char ack;
|
||||
ssize_t len = links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL);
|
||||
ssize_t len = all_links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL);
|
||||
if (len == 0) {
|
||||
links[i].sock.Close(); continue;
|
||||
all_links[i].sock.Close(); continue;
|
||||
} else if (len > 0) {
|
||||
utils::Assert(ack == kResetAck, "wrong Ack MSG");
|
||||
} else {
|
||||
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
|
||||
}
|
||||
// set back to nonblock mode
|
||||
links[i].sock.SetNonBlock(true);
|
||||
all_links[i].sock.SetNonBlock(true);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (links[i].sock.BadSocket()) return kSockError;
|
||||
if (all_links[i].sock.BadSocket()) return kSockError;
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
@ -320,8 +353,8 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
||||
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
|
||||
if (err_type == kSuccess) return true;
|
||||
// simple way, shutdown all links
|
||||
for (size_t i = 0; i < links.size(); ++i) {
|
||||
if (!links[i].sock.BadSocket()) links[i].sock.Close();
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
|
||||
}
|
||||
ReConnectLinks("recover");
|
||||
return false;
|
||||
@ -479,6 +512,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
||||
size_t size,
|
||||
int recv_link,
|
||||
const std::vector<bool> &req_in) {
|
||||
RefLinkVector &links = tree_links;
|
||||
// no need to run recovery for zero size message
|
||||
if (links.size() == 0 || size == 0) return kSuccess;
|
||||
utils::Assert(req_in.size() == links.size(), "TryRecoverData");
|
||||
@ -580,17 +614,48 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
||||
* \sa ReturnType
|
||||
*/
|
||||
AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
|
||||
RecoverType role = requester ? kRequestData : kHaveData;
|
||||
// check in local data
|
||||
RecoverType role = requester ? kRequestData : kHaveData;
|
||||
ReturnType succ;
|
||||
if (num_local_replica != 0) {
|
||||
if (requester) {
|
||||
// clear existing history, if any, before load
|
||||
local_rptr[local_chkpt_version].clear();
|
||||
local_chkpt[local_chkpt_version].clear();
|
||||
}
|
||||
// recover local checkpoint
|
||||
succ = TryRecoverLocalState(&local_rptr[local_chkpt_version],
|
||||
&local_chkpt[local_chkpt_version]);
|
||||
if (succ != kSuccess) return succ;
|
||||
int nlocal = std::max(static_cast<int>(local_rptr[local_chkpt_version].size()) - 1, 0);
|
||||
// check if everyone is OK
|
||||
unsigned state = 0;
|
||||
if (nlocal == num_local_replica + 1) {
|
||||
// complete recovery
|
||||
state = 1;
|
||||
} else if (nlocal == 0) {
|
||||
// get nothing
|
||||
state = 2;
|
||||
} else {
|
||||
// partially complete state
|
||||
state = 4;
|
||||
}
|
||||
succ = TryAllreduce(&state, sizeof(state), 1, op::Reducer<op::BitOR, unsigned>);
|
||||
if (succ != kSuccess) return succ;
|
||||
utils::Check(state == 1 || state == 2,
|
||||
"LoadCheckPoint: too many nodes fails, cannot recover local state");
|
||||
}
|
||||
// recover global checkpoint
|
||||
size_t size = this->global_checkpoint.length();
|
||||
int recv_link;
|
||||
std::vector<bool> req_in;
|
||||
ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);
|
||||
succ = TryDecideRouting(role, &size, &recv_link, &req_in);
|
||||
if (succ != kSuccess) return succ;
|
||||
if (role == kRequestData) {
|
||||
global_checkpoint.resize(size);
|
||||
}
|
||||
if (size == 0) return kSuccess;
|
||||
return TryRecoverData(role, &global_checkpoint[0], size, recv_link, req_in);
|
||||
return TryRecoverData(role, BeginPtr(global_checkpoint), size, recv_link, req_in);
|
||||
}
|
||||
/*!
|
||||
* \brief try to get the result of operation specified by seqno
|
||||
@ -607,11 +672,21 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
|
||||
* \sa ReturnType
|
||||
*/
|
||||
AllreduceRobust::ReturnType
|
||||
AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) { RecoverType role;
|
||||
AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) {
|
||||
// if minimum sequence requested is local check point ack,
|
||||
// this means all nodes have finished local check point, directly return
|
||||
if (seqno == ActionSummary::kLocalCheckAck) return kSuccess;
|
||||
|
||||
if (seqno == ActionSummary::kLocalCheckPoint) {
|
||||
// new version of local model
|
||||
int new_version = !local_chkpt_version;
|
||||
int nlocal = std::max(static_cast<int>(local_rptr[new_version].size()) - 1, 0);
|
||||
// if we reach this place, the user must have already set up the state once
|
||||
utils::Assert(nlocal == 1 || nlocal == num_local_replica + 1,
|
||||
"TryGetResult::Checkpoint");
|
||||
return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]);
|
||||
}
|
||||
// handles normal data recovery
|
||||
RecoverType role;
|
||||
if (!requester) {
|
||||
sendrecvbuf = resbuf.Query(seqno, &size);
|
||||
role = sendrecvbuf != NULL ? kHaveData : kPassData;
|
||||
@ -786,7 +861,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
||||
}
|
||||
chkpt.resize(rptr.back());
|
||||
// pass data through the link
|
||||
succ = RingPassing(&chkpt[0], rptr[nlocal], rptr[nread_end],
|
||||
succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end],
|
||||
rptr[nwrite_start], rptr[nread_end],
|
||||
ring_next, ring_prev);
|
||||
if (succ != kSuccess) {
|
||||
@ -849,7 +924,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
||||
}
|
||||
chkpt.resize(rptr.back());
|
||||
// pass data through the link
|
||||
succ = RingPassing(&chkpt[0], rptr[nlocal], rptr[nread_end],
|
||||
succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end],
|
||||
rptr[nwrite_start], rptr[nwrite_end],
|
||||
ring_prev, ring_next);
|
||||
if (succ != kSuccess) {
|
||||
@ -858,6 +933,57 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
* \brief try to checkpoint local state, this function is called in normal executation phase
|
||||
* of checkpoint that contains local state
|
||||
* the input state must exactly one saved state(local state of current node),
|
||||
* after complete, this function will get local state from previous num_local_replica nodes and put them
|
||||
* into local_chkpt and local_rptr
|
||||
*
|
||||
* It is also OK to call TryRecoverLocalState instead,
|
||||
* TryRecoverLocalState makes less assumption about the input, and requires more communications
|
||||
*
|
||||
* \param p_local_rptr the pointer to the segment pointers in the states array
|
||||
* \param p_local_chkpt the pointer to the storage of local check points
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType, TryRecoverLocalState
|
||||
*/
|
||||
AllreduceRobust::ReturnType
|
||||
AllreduceRobust::TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
|
||||
std::string *p_local_chkpt) {
|
||||
// if there is no local replica, we can do nothing
|
||||
if (num_local_replica == 0) return kSuccess;
|
||||
std::vector<size_t> &rptr = *p_local_rptr;
|
||||
std::string &chkpt = *p_local_chkpt;
|
||||
utils::Assert(rptr.size() == 2, "TryCheckinLocalState must have exactly 1 state");
|
||||
const int n = num_local_replica;
|
||||
std::vector<size_t> sizes(n + 1);
|
||||
sizes[0] = rptr[1] - rptr[0];
|
||||
ReturnType succ;
|
||||
// pass size through the link
|
||||
succ = RingPassing(BeginPtr(sizes),
|
||||
1 * sizeof(size_t),
|
||||
(n + 1) * sizeof(size_t),
|
||||
0 * sizeof(size_t),
|
||||
n * sizeof(size_t),
|
||||
ring_prev, ring_next);
|
||||
if (succ != kSuccess) return succ;
|
||||
// update rptr
|
||||
rptr.resize(n + 1);
|
||||
for (int i = 1; i < n; ++i) {
|
||||
rptr[i + 1] = rptr[i] + sizes[i];
|
||||
}
|
||||
chkpt.resize(rptr.back());
|
||||
// pass data through the link
|
||||
succ = RingPassing(BeginPtr(chkpt),
|
||||
rptr[1], rptr[n + 1],
|
||||
rptr[0], rptr[n],
|
||||
ring_prev, ring_next);
|
||||
if (succ != kSuccess) {
|
||||
rptr.resize(2); chkpt.resize(rptr.back()); return succ;
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
* \brief perform a ring passing to receive data from prev link, and sent data to next link
|
||||
* this allows data to stream over a ring structure
|
||||
@ -883,7 +1009,7 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
|
||||
size_t write_end,
|
||||
LinkRecord *read_link,
|
||||
LinkRecord *write_link) {
|
||||
if (links.size() == 0 || read_end == 0) return kSuccess;
|
||||
if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess;
|
||||
utils::Assert(read_end <= write_end, "boundary check");
|
||||
utils::Assert(read_ptr <= read_end, "boundary check");
|
||||
utils::Assert(write_ptr <= write_end, "boundary check");
|
||||
|
||||
@ -372,6 +372,23 @@ class AllreduceRobust : public AllreduceBase {
|
||||
*/
|
||||
ReturnType TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
||||
std::string *p_local_chkpt);
|
||||
/*!
|
||||
* \brief try to checkpoint local state, this function is called in normal execution phase
|
||||
* of checkpoint that contains local state
|
||||
* the input state must be exactly one saved state (local state of current node),
|
||||
* after complete, this function will get local state from previous num_local_replica nodes and put them
|
||||
* into local_chkpt and local_rptr
|
||||
*
|
||||
* It is also OK to call TryRecoverLocalState instead,
|
||||
* TryRecoverLocalState makes less assumption about the input, and requires more communications
|
||||
*
|
||||
* \param p_local_rptr the pointer to the segment pointers in the states array
|
||||
* \param p_local_chkpt the pointer to the storage of local check points
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType, TryRecoverLocalState
|
||||
*/
|
||||
ReturnType TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
|
||||
std::string *p_local_chkpt);
|
||||
/*!
|
||||
* \brief perform a ring passing to receive data from prev link, and sent data to next link
|
||||
* this allows data to stream over a ring structure
|
||||
@ -441,7 +458,7 @@ class AllreduceRobust : public AllreduceBase {
|
||||
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring
|
||||
std::vector<size_t> local_rptr[2];
|
||||
// storage for local model replicas
|
||||
std::string local_checkpoint[2];
|
||||
std::string local_chkpt[2];
|
||||
// version of local checkpoint can be 1 or 0
|
||||
int local_chkpt_version;
|
||||
};
|
||||
|
||||
@ -63,25 +63,32 @@ class SlaveEntry:
|
||||
return job_map[self.jobid]
|
||||
return -1
|
||||
|
||||
def get_neighbor(self, rank, nslave):
|
||||
rank = rank + 1
|
||||
ret = []
|
||||
if rank > 1:
|
||||
ret.append(rank / 2 - 1)
|
||||
if rank * 2 - 1 < nslave:
|
||||
ret.append(rank * 2 - 1)
|
||||
if rank * 2 < nslave:
|
||||
ret.append(rank * 2)
|
||||
return set(ret)
|
||||
|
||||
def assign_rank(self, rank, wait_conn, nslave):
|
||||
def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map):
|
||||
self.rank = rank
|
||||
nnset = self.get_neighbor(rank, nslave)
|
||||
nnset = set(tree_map[rank])
|
||||
rprev, rnext = ring_map[rank]
|
||||
self.sock.sendint(rank)
|
||||
# send parent rank
|
||||
self.sock.sendint((rank + 1) / 2 - 1)
|
||||
self.sock.sendint(parent_map[rank])
|
||||
# send world size
|
||||
self.sock.sendint(nslave)
|
||||
self.sock.sendint(len(tree_map))
|
||||
self.sock.sendint(len(nnset))
|
||||
# send the rprev and next link
|
||||
for r in nnset:
|
||||
self.sock.sendint(r)
|
||||
# send prev link
|
||||
if rprev != -1 and rprev != rank:
|
||||
nnset.add(rprev)
|
||||
self.sock.sendint(rprev)
|
||||
else:
|
||||
self.sock.sendint(-1)
|
||||
# send next link
|
||||
if rnext != -1 and rnext != rank:
|
||||
nnset.add(rnext)
|
||||
self.sock.sendint(rnext)
|
||||
else:
|
||||
self.sock.sendint(-1)
|
||||
|
||||
while True:
|
||||
ngood = self.sock.recvint()
|
||||
goodset = set([])
|
||||
@ -131,8 +138,35 @@ class Tracker:
|
||||
self.sock.close()
|
||||
def slave_args(self):
|
||||
return ['rabit_tracker_uri=%s' % socket.gethostname(),
|
||||
'rabit_tracker_port=%s' % self.port]
|
||||
'rabit_tracker_port=%s' % self.port]
|
||||
def get_neighbor(self, rank, nslave):
    """Return the tree neighbors of a node as a list of ranks.

    The ranks are laid out as a binary tree using one-based positions
    internally: the parent (if any) comes first, followed by the children
    whose rank is smaller than nslave.

    Parameters
    ----------
    rank : zero-based rank of the node
    nslave : total number of nodes

    Returns
    -------
    list of zero-based integer ranks
    """
    # switch to one-based position so parent/children are pos // 2 and pos * 2 (+ 1)
    rank = rank + 1
    ret = []
    if rank > 1:
        # floor division keeps the parent rank an integer on every interpreter
        ret.append(rank // 2 - 1)
    if rank * 2 - 1 < nslave:
        ret.append(rank * 2 - 1)
    if rank * 2 < nslave:
        ret.append(rank * 2)
    return ret
|
||||
def get_tree(self, nslave):
    """Build the tree topology for nslave nodes.

    Returns
    -------
    (tree_map, parent_map) where tree_map[r] is the result of
    self.get_neighbor(r, nslave) and parent_map[r] is the rank of the
    parent of r, which is -1 for rank 0.
    """
    tree_map = {}
    parent_map = {}
    for r in range(nslave):
        tree_map[r] = self.get_neighbor(r, nslave)
        # floor division keeps the parent rank an integer on every interpreter
        parent_map[r] = (r + 1) // 2 - 1
    return tree_map, parent_map
|
||||
def get_ring(self, tree_map, parent_map):
    """Build the ring topology over all ranks present in tree_map.

    Each rank r is mapped to the pair (previous rank, next rank) in a ring
    that follows the rank order and wraps around at the ends.
    parent_map is accepted but not used.
    """
    total = len(tree_map)
    # (r - 1) % total wraps rank 0 back to the last rank
    return dict((node, ((node - 1) % total, (node + 1) % total))
                for node in range(total))
|
||||
def accept_slaves(self, nslave):
|
||||
tree_map, parent_map = self.get_tree(nslave)
|
||||
ring_map = self.get_ring(tree_map, parent_map)
|
||||
# set of nodes that finished the job
|
||||
shutdown = {}
|
||||
# set of nodes that is waiting for connections
|
||||
@ -163,7 +197,7 @@ class Tracker:
|
||||
rank = todo_nodes.pop(0)
|
||||
if s.jobid != 'NULL':
|
||||
job_map[s.jobid] = rank
|
||||
s.assign_rank(rank, wait_conn, nslave)
|
||||
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
||||
if s.wait_accept > 0:
|
||||
wait_conn[rank] = s
|
||||
print 'All nodes finishes job'
|
||||
|
||||
@ -153,7 +153,8 @@ class Socket {
|
||||
* \param end_port ending port number to try
|
||||
* \return the port successfully bind to, return -1 if failed to bind any port
|
||||
*/
|
||||
inline int TryBindHost(int start_port, int end_port) {
|
||||
inline int TryBindHost(int start_port, int end_port) {
|
||||
// TODO, add prefix check
|
||||
for (int port = start_port; port < end_port; ++port) {
|
||||
SockAddr addr("0.0.0.0", port);
|
||||
if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0) {
|
||||
|
||||
@ -187,5 +187,13 @@ inline const T *BeginPtr(const std::vector<T> &vec) {
|
||||
return &vec[0];
|
||||
}
|
||||
}
|
||||
/*!
 * \brief get the pointer to the first character of a string
 * \param str the string to take the pointer from
 * \return pointer to the beginning of str, or NULL when str is empty
 */
inline char* BeginPtr(std::string &str) {
  if (!str.empty()) return &str[0];
  return NULL;
}
|
||||
/*!
 * \brief get the read-only pointer to the first character of a string
 * \param str the string to take the pointer from
 * \return pointer to the beginning of str, or NULL when str is empty
 */
inline const char* BeginPtr(const std::string &str) {
  if (!str.empty()) return &str[0];
  return NULL;
}
|
||||
} // namespace rabit
|
||||
#endif // RABIT_UTILS_H_
|
||||
|
||||
@ -24,7 +24,10 @@ def mpi_submit(nslave, args):
|
||||
args arguments to launch each job
|
||||
this usually includes the parameters of master_uri and parameters passed into submit
|
||||
"""
|
||||
cmd = ' '.join(['mpirun -n %d --hostfile %s' % (nslave, args[0])] + args[1:])
|
||||
if args[0] == 'local':
|
||||
cmd = ' '.join(['mpirun -n %d' % (nslave)] + args[1:])
|
||||
else:
|
||||
cmd = ' '.join(['mpirun -n %d --hostfile %s' % (nslave, args[0])] + args[1:])
|
||||
print cmd
|
||||
subprocess.check_call(cmd, shell = True)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user