add ring setup version
This commit is contained in:
parent
322e40c72e
commit
e2adce1cc1
@ -7,6 +7,7 @@
|
|||||||
#define _CRT_SECURE_NO_WARNINGS
|
#define _CRT_SECURE_NO_WARNINGS
|
||||||
#define _CRT_SECURE_NO_DEPRECATE
|
#define _CRT_SECURE_NO_DEPRECATE
|
||||||
#define NOMINMAX
|
#define NOMINMAX
|
||||||
|
#include <map>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include "./allreduce_base.h"
|
#include "./allreduce_base.h"
|
||||||
@ -43,17 +44,18 @@ void AllreduceBase::Init(void) {
|
|||||||
}
|
}
|
||||||
// start socket
|
// start socket
|
||||||
utils::Socket::Startup();
|
utils::Socket::Startup();
|
||||||
utils::Assert(links.size() == 0, "can only call Init once");
|
utils::Assert(all_links.size() == 0, "can only call Init once");
|
||||||
this->host_uri = utils::SockAddr::GetHostName();
|
this->host_uri = utils::SockAddr::GetHostName();
|
||||||
// get information from tracker
|
// get information from tracker
|
||||||
this->ReConnectLinks();
|
this->ReConnectLinks();
|
||||||
}
|
}
|
||||||
|
|
||||||
void AllreduceBase::Shutdown(void) {
|
void AllreduceBase::Shutdown(void) {
|
||||||
for (size_t i = 0; i < links.size(); ++i) {
|
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||||
links[i].sock.Close();
|
all_links[i].sock.Close();
|
||||||
}
|
}
|
||||||
links.clear();
|
all_links.clear();
|
||||||
|
tree_links.plinks.clear();
|
||||||
|
|
||||||
if (tracker_uri == "NULL") return;
|
if (tracker_uri == "NULL") return;
|
||||||
int magic = kMagic;
|
int magic = kMagic;
|
||||||
@ -121,8 +123,12 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
||||||
tracker.SendStr(task_id);
|
tracker.SendStr(task_id);
|
||||||
tracker.SendStr(std::string(cmd));
|
tracker.SendStr(std::string(cmd));
|
||||||
|
// the rank of previous link, next link in ring
|
||||||
|
int prev_rank, next_rank;
|
||||||
|
// the rank of neighbors
|
||||||
|
std::map<int, int> tree_neighbors;
|
||||||
{// get new ranks
|
{// get new ranks
|
||||||
int newrank;
|
int newrank, num_neighbors;
|
||||||
utils::Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
|
utils::Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
|
||||||
"ReConnectLink failure 4");
|
"ReConnectLink failure 4");
|
||||||
utils::Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == sizeof(parent_rank),
|
utils::Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == sizeof(parent_rank),
|
||||||
@ -130,8 +136,20 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
utils::Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
utils::Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
||||||
"ReConnectLink failure 4");
|
"ReConnectLink failure 4");
|
||||||
utils::Assert(rank == -1 || newrank == rank, "must keep rank to same if the node already have one");
|
utils::Assert(rank == -1 || newrank == rank, "must keep rank to same if the node already have one");
|
||||||
rank = newrank;
|
rank = newrank;
|
||||||
}
|
utils::Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == sizeof(num_neighbors),
|
||||||
|
"ReConnectLink failure 4");
|
||||||
|
for (int i = 0; i < num_neighbors; ++i) {
|
||||||
|
int nrank;
|
||||||
|
utils::Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
|
||||||
|
"ReConnectLink failure 4");
|
||||||
|
tree_neighbors[nrank] = 1;
|
||||||
|
}
|
||||||
|
utils::Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
|
||||||
|
"ReConnectLink failure 4");
|
||||||
|
utils::Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
|
||||||
|
"ReConnectLink failure 4");
|
||||||
|
}
|
||||||
// create listening socket
|
// create listening socket
|
||||||
utils::TCPSocket sock_listen;
|
utils::TCPSocket sock_listen;
|
||||||
sock_listen.Create();
|
sock_listen.Create();
|
||||||
@ -144,11 +162,11 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
do {
|
do {
|
||||||
// send over good links
|
// send over good links
|
||||||
std::vector<int> good_link;
|
std::vector<int> good_link;
|
||||||
for (size_t i = 0; i < links.size(); ++i) {
|
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||||
if (!links[i].sock.BadSocket()) {
|
if (!all_links[i].sock.BadSocket()) {
|
||||||
good_link.push_back(static_cast<int>(links[i].rank));
|
good_link.push_back(static_cast<int>(all_links[i].rank));
|
||||||
} else {
|
} else {
|
||||||
if (!links[i].sock.IsClosed()) links[i].sock.Close();
|
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
int ngood = static_cast<int>(good_link.size());
|
int ngood = static_cast<int>(good_link.size());
|
||||||
@ -178,13 +196,13 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 13");
|
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 13");
|
||||||
utils::Check(hrank == r.rank, "ReConnectLink failure, link rank inconsistent");
|
utils::Check(hrank == r.rank, "ReConnectLink failure, link rank inconsistent");
|
||||||
bool match = false;
|
bool match = false;
|
||||||
for (size_t i = 0; i < links.size(); ++i) {
|
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||||
if (links[i].rank == hrank) {
|
if (all_links[i].rank == hrank) {
|
||||||
utils::Assert(links[i].sock.IsClosed(), "Override a link that is active");
|
utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active");
|
||||||
links[i].sock = r.sock; match = true; break;
|
all_links[i].sock = r.sock; match = true; break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!match) links.push_back(r);
|
if (!match) all_links.push_back(r);
|
||||||
}
|
}
|
||||||
utils::Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), "ReConnectLink failure 14");
|
utils::Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), "ReConnectLink failure 14");
|
||||||
} while (num_error != 0);
|
} while (num_error != 0);
|
||||||
@ -199,27 +217,35 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 15");
|
utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 15");
|
||||||
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 15");
|
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 15");
|
||||||
bool match = false;
|
bool match = false;
|
||||||
for (size_t i = 0; i < links.size(); ++i) {
|
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||||
if (links[i].rank == r.rank) {
|
if (all_links[i].rank == r.rank) {
|
||||||
utils::Assert(links[i].sock.IsClosed(), "Override a link that is active");
|
utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active");
|
||||||
links[i].sock = r.sock; match = true; break;
|
all_links[i].sock = r.sock; match = true; break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!match) links.push_back(r);
|
if (!match) all_links.push_back(r);
|
||||||
}
|
}
|
||||||
// close listening sockets
|
// close listening sockets
|
||||||
sock_listen.Close();
|
sock_listen.Close();
|
||||||
this->parent_index = -1;
|
this->parent_index = -1;
|
||||||
// setup selecter
|
// setup tree links and ring structure
|
||||||
for (size_t i = 0; i < links.size(); ++i) {
|
tree_links.plinks.clear();
|
||||||
utils::Assert(!links[i].sock.BadSocket(), "ReConnectLink: bad socket");
|
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||||
|
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
|
||||||
// set the socket to non-blocking mode
|
// set the socket to non-blocking mode
|
||||||
links[i].sock.SetNonBlock(true);
|
all_links[i].sock.SetNonBlock(true);
|
||||||
if (links[i].rank == parent_rank) parent_index = static_cast<int>(i);
|
if (tree_neighbors.count(all_links[i].rank) != 0) {
|
||||||
}
|
if (all_links[i].rank == parent_rank) {
|
||||||
if (parent_rank != -1) {
|
parent_index = static_cast<int>(tree_links.plinks.size());
|
||||||
utils::Assert(parent_index != -1, "cannot find parent in the link");
|
}
|
||||||
|
tree_links.plinks.push_back(&all_links[i]);
|
||||||
|
}
|
||||||
|
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
|
||||||
|
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
|
||||||
}
|
}
|
||||||
|
utils::Assert(parent_rank == -1 || parent_index != -1, "cannot find parent in the link");
|
||||||
|
utils::Assert(prev_rank == -1 || ring_prev != NULL, "cannot find prev ring in the link");
|
||||||
|
utils::Assert(next_rank == -1 || ring_next != NULL, "cannot find next ring in the link");
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||||
@ -241,6 +267,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
|||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer) {
|
ReduceFunction reducer) {
|
||||||
|
RefLinkVector &links = tree_links;
|
||||||
if (links.size() == 0 || count == 0) return kSuccess;
|
if (links.size() == 0 || count == 0) return kSuccess;
|
||||||
// total size of message
|
// total size of message
|
||||||
const size_t total_size = type_nbytes * count;
|
const size_t total_size = type_nbytes * count;
|
||||||
@ -391,8 +418,9 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
|||||||
*/
|
*/
|
||||||
AllreduceBase::ReturnType
|
AllreduceBase::ReturnType
|
||||||
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||||
|
RefLinkVector &links = tree_links;
|
||||||
if (links.size() == 0 || total_size == 0) return kSuccess;
|
if (links.size() == 0 || total_size == 0) return kSuccess;
|
||||||
utils::Check(root < world_size, "Broadcast: root should be smaller than world size");
|
utils::Check(root < world_size, "Broadcast: root should be smaller than world size");
|
||||||
// number of links
|
// number of links
|
||||||
const int nlink = static_cast<int>(links.size());
|
const int nlink = static_cast<int>(links.size());
|
||||||
// size of space already read from data
|
// size of space already read from data
|
||||||
|
|||||||
@ -259,6 +259,19 @@ class AllreduceBase : public IEngine {
|
|||||||
// aligned with 64 bits, will be able to perform 64 bits operations freely
|
// aligned with 64 bits, will be able to perform 64 bits operations freely
|
||||||
std::vector<uint64_t> buffer_;
|
std::vector<uint64_t> buffer_;
|
||||||
};
|
};
|
||||||
|
/*!
|
||||||
|
* \brief simple data structure that works like a vector
|
||||||
|
* but takes reference instead of space
|
||||||
|
*/
|
||||||
|
struct RefLinkVector {
|
||||||
|
std::vector<LinkRecord*> plinks;
|
||||||
|
inline LinkRecord &operator[](size_t i) {
|
||||||
|
return *plinks[i];
|
||||||
|
}
|
||||||
|
inline size_t size(void) const {
|
||||||
|
return plinks.size();
|
||||||
|
}
|
||||||
|
};
|
||||||
/*!
|
/*!
|
||||||
* \brief connect to the tracker to fix the the missing links
|
* \brief connect to the tracker to fix the the missing links
|
||||||
* this function is also used when the engine start up
|
* this function is also used when the engine start up
|
||||||
@ -306,9 +319,11 @@ class AllreduceBase : public IEngine {
|
|||||||
int parent_index;
|
int parent_index;
|
||||||
// rank of parent node, can be -1
|
// rank of parent node, can be -1
|
||||||
int parent_rank;
|
int parent_rank;
|
||||||
// sockets of all links
|
// sockets of all links this connects to
|
||||||
std::vector<LinkRecord> links;
|
std::vector<LinkRecord> all_links;
|
||||||
// pointer to someplace in the ring
|
// all the links in the reduction tree connection
|
||||||
|
RefLinkVector tree_links;
|
||||||
|
// pointer to links in the ring
|
||||||
LinkRecord *ring_prev, *ring_next;
|
LinkRecord *ring_prev, *ring_next;
|
||||||
//----- meta information-----
|
//----- meta information-----
|
||||||
// unique identifier of the possible job this process is doing
|
// unique identifier of the possible job this process is doing
|
||||||
|
|||||||
@ -37,6 +37,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
|
|||||||
const std::vector<EdgeType> &edge_in,
|
const std::vector<EdgeType> &edge_in,
|
||||||
size_t out_index)
|
size_t out_index)
|
||||||
) {
|
) {
|
||||||
|
RefLinkVector &links = tree_links;
|
||||||
if (links.size() == 0) return kSuccess;
|
if (links.size() == 0) return kSuccess;
|
||||||
// number of links
|
// number of links
|
||||||
const int nlink = static_cast<int>(links.size());
|
const int nlink = static_cast<int>(links.size());
|
||||||
|
|||||||
@ -11,13 +11,14 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include "./io.h"
|
#include "./io.h"
|
||||||
#include "./utils.h"
|
#include "./utils.h"
|
||||||
|
#include "./rabit.h"
|
||||||
#include "./allreduce_robust.h"
|
#include "./allreduce_robust.h"
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
AllreduceRobust::AllreduceRobust(void) {
|
AllreduceRobust::AllreduceRobust(void) {
|
||||||
result_buffer_round = 1;
|
result_buffer_round = 1;
|
||||||
num_local_replica = 2;
|
num_local_replica = 0;
|
||||||
seq_counter = 0;
|
seq_counter = 0;
|
||||||
}
|
}
|
||||||
/*! \brief shutdown the engine */
|
/*! \brief shutdown the engine */
|
||||||
@ -131,9 +132,17 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
|
|||||||
*/
|
*/
|
||||||
int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
|
int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
|
||||||
utils::ISerializable *local_model) {
|
utils::ISerializable *local_model) {
|
||||||
utils::Check(local_model == NULL, "CheckPoint local_model is not yet supported");
|
if (num_local_replica == 0) {
|
||||||
// check if we succesfll
|
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
|
||||||
|
}
|
||||||
|
// check if we succesful
|
||||||
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp)) {
|
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp)) {
|
||||||
|
if (local_model != NULL) {
|
||||||
|
// load in local model
|
||||||
|
utils::MemoryFixSizeBuffer fs(BeginPtr(local_chkpt[local_chkpt_version]),
|
||||||
|
local_rptr[local_chkpt_version][1]);
|
||||||
|
local_model->Load(fs);
|
||||||
|
}
|
||||||
// reset result buffer
|
// reset result buffer
|
||||||
resbuf.Clear(); seq_counter = 0;
|
resbuf.Clear(); seq_counter = 0;
|
||||||
// load from buffer
|
// load from buffer
|
||||||
@ -170,7 +179,31 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
|
|||||||
*/
|
*/
|
||||||
void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
|
void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
|
||||||
const utils::ISerializable *local_model) {
|
const utils::ISerializable *local_model) {
|
||||||
utils::Assert(local_model == NULL, "CheckPoint local model is not supported yet");
|
if (num_local_replica == 0) {
|
||||||
|
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
|
||||||
|
}
|
||||||
|
if (num_local_replica != 0) {
|
||||||
|
while (true) {
|
||||||
|
if (RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckPoint)) break;
|
||||||
|
// save model model to new version place
|
||||||
|
int new_version = !local_chkpt_version;
|
||||||
|
local_chkpt[new_version].clear();
|
||||||
|
utils::MemoryBufferStream fs(&local_chkpt[new_version]);
|
||||||
|
if (local_model != NULL) {
|
||||||
|
local_model->Save(fs);
|
||||||
|
}
|
||||||
|
local_rptr[new_version].clear();
|
||||||
|
local_rptr[new_version].push_back(0);
|
||||||
|
local_rptr[new_version].push_back(local_chkpt[new_version].length());
|
||||||
|
if (CheckAndRecover(TryCheckinLocalState(&local_rptr[new_version],
|
||||||
|
&local_chkpt[new_version]))) break;
|
||||||
|
}
|
||||||
|
// run the ack phase
|
||||||
|
utils::Assert(RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckAck),
|
||||||
|
"check point must return true");
|
||||||
|
// switch pointer to new version
|
||||||
|
local_chkpt_version = !local_chkpt_version;
|
||||||
|
}
|
||||||
// execute checkpoint, note: when checkpoint existing, load will not happen
|
// execute checkpoint, note: when checkpoint existing, load will not happen
|
||||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
|
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
|
||||||
"check point must return true");
|
"check point must return true");
|
||||||
@ -199,32 +232,32 @@ void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
|
|||||||
*/
|
*/
|
||||||
AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
||||||
// number of links
|
// number of links
|
||||||
const int nlink = static_cast<int>(links.size());
|
const int nlink = static_cast<int>(all_links.size());
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size);
|
all_links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size);
|
||||||
links[i].ResetSize();
|
all_links[i].ResetSize();
|
||||||
}
|
}
|
||||||
// read and discard data from all channels until pass mark
|
// read and discard data from all channels until pass mark
|
||||||
while (true) {
|
while (true) {
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (links[i].sock.BadSocket()) continue;
|
if (all_links[i].sock.BadSocket()) continue;
|
||||||
if (links[i].size_write == 0) {
|
if (all_links[i].size_write == 0) {
|
||||||
char sig = kOOBReset;
|
char sig = kOOBReset;
|
||||||
ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB);
|
ssize_t len = all_links[i].sock.Send(&sig, sizeof(sig), MSG_OOB);
|
||||||
// error will be filtered in next loop
|
// error will be filtered in next loop
|
||||||
if (len == sizeof(sig)) links[i].size_write = 1;
|
if (len == sizeof(sig)) all_links[i].size_write = 1;
|
||||||
}
|
}
|
||||||
if (links[i].size_write == 1) {
|
if (all_links[i].size_write == 1) {
|
||||||
char sig = kResetMark;
|
char sig = kResetMark;
|
||||||
ssize_t len = links[i].sock.Send(&sig, sizeof(sig));
|
ssize_t len = all_links[i].sock.Send(&sig, sizeof(sig));
|
||||||
if (len == sizeof(sig)) links[i].size_write = 2;
|
if (len == sizeof(sig)) all_links[i].size_write = 2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
utils::SelectHelper rsel;
|
utils::SelectHelper rsel;
|
||||||
bool finished = true;
|
bool finished = true;
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (links[i].size_write != 2 && !links[i].sock.BadSocket()) {
|
if (all_links[i].size_write != 2 && !all_links[i].sock.BadSocket()) {
|
||||||
rsel.WatchWrite(links[i].sock); finished = false;
|
rsel.WatchWrite(all_links[i].sock); finished = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (finished) break;
|
if (finished) break;
|
||||||
@ -232,32 +265,32 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
|||||||
rsel.Select();
|
rsel.Select();
|
||||||
}
|
}
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (!links[i].sock.BadSocket()) {
|
if (!all_links[i].sock.BadSocket()) {
|
||||||
utils::SelectHelper::WaitExcept(links[i].sock);
|
utils::SelectHelper::WaitExcept(all_links[i].sock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
while (true) {
|
while (true) {
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (links[i].size_read == 0) {
|
if (all_links[i].size_read == 0) {
|
||||||
int atmark = links[i].sock.AtMark();
|
int atmark = all_links[i].sock.AtMark();
|
||||||
if (atmark < 0) {
|
if (atmark < 0) {
|
||||||
utils::Assert(links[i].sock.BadSocket(), "must already gone bad");
|
utils::Assert(all_links[i].sock.BadSocket(), "must already gone bad");
|
||||||
} else if (atmark > 0) {
|
} else if (atmark > 0) {
|
||||||
links[i].size_read = 1;
|
all_links[i].size_read = 1;
|
||||||
} else {
|
} else {
|
||||||
// no at mark, read and discard data
|
// no at mark, read and discard data
|
||||||
ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size);
|
ssize_t len = all_links[i].sock.Recv(all_links[i].buffer_head, all_links[i].buffer_size);
|
||||||
if (links[i].sock.AtMark()) links[i].size_read = 1;
|
if (all_links[i].sock.AtMark()) all_links[i].size_read = 1;
|
||||||
// zero length, remote closed the connection, close socket
|
// zero length, remote closed the connection, close socket
|
||||||
if (len == 0) links[i].sock.Close();
|
if (len == 0) all_links[i].sock.Close();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
utils::SelectHelper rsel;
|
utils::SelectHelper rsel;
|
||||||
bool finished = true;
|
bool finished = true;
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (links[i].size_read == 0 && !links[i].sock.BadSocket()) {
|
if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) {
|
||||||
rsel.WatchRead(links[i].sock); finished = false;
|
rsel.WatchRead(all_links[i].sock); finished = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (finished) break;
|
if (finished) break;
|
||||||
@ -266,22 +299,22 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
|||||||
|
|
||||||
// start synchronization, use blocking I/O to avoid select
|
// start synchronization, use blocking I/O to avoid select
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (!links[i].sock.BadSocket()) {
|
if (!all_links[i].sock.BadSocket()) {
|
||||||
char oob_mark;
|
char oob_mark;
|
||||||
links[i].sock.SetNonBlock(false);
|
all_links[i].sock.SetNonBlock(false);
|
||||||
ssize_t len = links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL);
|
ssize_t len = all_links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL);
|
||||||
if (len == 0) {
|
if (len == 0) {
|
||||||
links[i].sock.Close(); continue;
|
all_links[i].sock.Close(); continue;
|
||||||
} else if (len > 0) {
|
} else if (len > 0) {
|
||||||
utils::Assert(oob_mark == kResetMark, "wrong oob msg");
|
utils::Assert(oob_mark == kResetMark, "wrong oob msg");
|
||||||
utils::Assert(links[i].sock.AtMark() != 1, "should already read past mark");
|
utils::Assert(all_links[i].sock.AtMark() != 1, "should already read past mark");
|
||||||
} else {
|
} else {
|
||||||
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
|
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
|
||||||
}
|
}
|
||||||
// send out ack
|
// send out ack
|
||||||
char ack = kResetAck;
|
char ack = kResetAck;
|
||||||
while (true) {
|
while (true) {
|
||||||
len = links[i].sock.Send(&ack, sizeof(ack));
|
len = all_links[i].sock.Send(&ack, sizeof(ack));
|
||||||
if (len == sizeof(ack)) break;
|
if (len == sizeof(ack)) break;
|
||||||
if (len == -1) {
|
if (len == -1) {
|
||||||
if (errno != EAGAIN && errno != EWOULDBLOCK) break;
|
if (errno != EAGAIN && errno != EWOULDBLOCK) break;
|
||||||
@ -291,22 +324,22 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
|||||||
}
|
}
|
||||||
// wait all ack
|
// wait all ack
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (!links[i].sock.BadSocket()) {
|
if (!all_links[i].sock.BadSocket()) {
|
||||||
char ack;
|
char ack;
|
||||||
ssize_t len = links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL);
|
ssize_t len = all_links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL);
|
||||||
if (len == 0) {
|
if (len == 0) {
|
||||||
links[i].sock.Close(); continue;
|
all_links[i].sock.Close(); continue;
|
||||||
} else if (len > 0) {
|
} else if (len > 0) {
|
||||||
utils::Assert(ack == kResetAck, "wrong Ack MSG");
|
utils::Assert(ack == kResetAck, "wrong Ack MSG");
|
||||||
} else {
|
} else {
|
||||||
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
|
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
|
||||||
}
|
}
|
||||||
// set back to nonblock mode
|
// set back to nonblock mode
|
||||||
links[i].sock.SetNonBlock(true);
|
all_links[i].sock.SetNonBlock(true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (links[i].sock.BadSocket()) return kSockError;
|
if (all_links[i].sock.BadSocket()) return kSockError;
|
||||||
}
|
}
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
@ -320,8 +353,8 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
|||||||
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
|
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
|
||||||
if (err_type == kSuccess) return true;
|
if (err_type == kSuccess) return true;
|
||||||
// simple way, shutdown all links
|
// simple way, shutdown all links
|
||||||
for (size_t i = 0; i < links.size(); ++i) {
|
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||||
if (!links[i].sock.BadSocket()) links[i].sock.Close();
|
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
|
||||||
}
|
}
|
||||||
ReConnectLinks("recover");
|
ReConnectLinks("recover");
|
||||||
return false;
|
return false;
|
||||||
@ -479,6 +512,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
|||||||
size_t size,
|
size_t size,
|
||||||
int recv_link,
|
int recv_link,
|
||||||
const std::vector<bool> &req_in) {
|
const std::vector<bool> &req_in) {
|
||||||
|
RefLinkVector &links = tree_links;
|
||||||
// no need to run recovery for zero size message
|
// no need to run recovery for zero size message
|
||||||
if (links.size() == 0 || size == 0) return kSuccess;
|
if (links.size() == 0 || size == 0) return kSuccess;
|
||||||
utils::Assert(req_in.size() == links.size(), "TryRecoverData");
|
utils::Assert(req_in.size() == links.size(), "TryRecoverData");
|
||||||
@ -580,17 +614,48 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
|||||||
* \sa ReturnType
|
* \sa ReturnType
|
||||||
*/
|
*/
|
||||||
AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
|
AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
|
||||||
RecoverType role = requester ? kRequestData : kHaveData;
|
// check in local data
|
||||||
|
RecoverType role = requester ? kRequestData : kHaveData;
|
||||||
|
ReturnType succ;
|
||||||
|
if (num_local_replica != 0) {
|
||||||
|
if (requester) {
|
||||||
|
// clear existing history, if any, before load
|
||||||
|
local_rptr[local_chkpt_version].clear();
|
||||||
|
local_chkpt[local_chkpt_version].clear();
|
||||||
|
}
|
||||||
|
// recover local checkpoint
|
||||||
|
succ = TryRecoverLocalState(&local_rptr[local_chkpt_version],
|
||||||
|
&local_chkpt[local_chkpt_version]);
|
||||||
|
if (succ != kSuccess) return succ;
|
||||||
|
int nlocal = std::max(static_cast<int>(local_rptr[local_chkpt_version].size()) - 1, 0);
|
||||||
|
// check if everyone is OK
|
||||||
|
unsigned state = 0;
|
||||||
|
if (nlocal == num_local_replica + 1) {
|
||||||
|
// complete recovery
|
||||||
|
state = 1;
|
||||||
|
} else if (nlocal == 0) {
|
||||||
|
// get nothing
|
||||||
|
state = 2;
|
||||||
|
} else {
|
||||||
|
// partially complete state
|
||||||
|
state = 4;
|
||||||
|
}
|
||||||
|
succ = TryAllreduce(&state, sizeof(state), 1, op::Reducer<op::BitOR, unsigned>);
|
||||||
|
if (succ != kSuccess) return succ;
|
||||||
|
utils::Check(state == 1 || state == 2,
|
||||||
|
"LoadCheckPoint: too many nodes fails, cannot recover local state");
|
||||||
|
}
|
||||||
|
// recover global checkpoint
|
||||||
size_t size = this->global_checkpoint.length();
|
size_t size = this->global_checkpoint.length();
|
||||||
int recv_link;
|
int recv_link;
|
||||||
std::vector<bool> req_in;
|
std::vector<bool> req_in;
|
||||||
ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);
|
succ = TryDecideRouting(role, &size, &recv_link, &req_in);
|
||||||
if (succ != kSuccess) return succ;
|
if (succ != kSuccess) return succ;
|
||||||
if (role == kRequestData) {
|
if (role == kRequestData) {
|
||||||
global_checkpoint.resize(size);
|
global_checkpoint.resize(size);
|
||||||
}
|
}
|
||||||
if (size == 0) return kSuccess;
|
if (size == 0) return kSuccess;
|
||||||
return TryRecoverData(role, &global_checkpoint[0], size, recv_link, req_in);
|
return TryRecoverData(role, BeginPtr(global_checkpoint), size, recv_link, req_in);
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief try to get the result of operation specified by seqno
|
* \brief try to get the result of operation specified by seqno
|
||||||
@ -607,11 +672,21 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
|
|||||||
* \sa ReturnType
|
* \sa ReturnType
|
||||||
*/
|
*/
|
||||||
AllreduceRobust::ReturnType
|
AllreduceRobust::ReturnType
|
||||||
AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) { RecoverType role;
|
AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) {
|
||||||
// if minimum sequence requested is local check point ack,
|
// if minimum sequence requested is local check point ack,
|
||||||
// this means all nodes have finished local check point, directly return
|
// this means all nodes have finished local check point, directly return
|
||||||
if (seqno == ActionSummary::kLocalCheckAck) return kSuccess;
|
if (seqno == ActionSummary::kLocalCheckAck) return kSuccess;
|
||||||
|
if (seqno == ActionSummary::kLocalCheckPoint) {
|
||||||
|
// new version of local model
|
||||||
|
int new_version = !local_chkpt_version;
|
||||||
|
int nlocal = std::max(static_cast<int>(local_rptr[new_version].size()) - 1, 0);
|
||||||
|
// if we goes to this place, use must have already setup the state once
|
||||||
|
utils::Assert(nlocal == 1 || nlocal == num_local_replica + 1,
|
||||||
|
"TryGetResult::Checkpoint");
|
||||||
|
return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]);
|
||||||
|
}
|
||||||
|
// handles normal data recovery
|
||||||
|
RecoverType role;
|
||||||
if (!requester) {
|
if (!requester) {
|
||||||
sendrecvbuf = resbuf.Query(seqno, &size);
|
sendrecvbuf = resbuf.Query(seqno, &size);
|
||||||
role = sendrecvbuf != NULL ? kHaveData : kPassData;
|
role = sendrecvbuf != NULL ? kHaveData : kPassData;
|
||||||
@ -786,7 +861,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
|||||||
}
|
}
|
||||||
chkpt.resize(rptr.back());
|
chkpt.resize(rptr.back());
|
||||||
// pass data through the link
|
// pass data through the link
|
||||||
succ = RingPassing(&chkpt[0], rptr[nlocal], rptr[nread_end],
|
succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end],
|
||||||
rptr[nwrite_start], rptr[nread_end],
|
rptr[nwrite_start], rptr[nread_end],
|
||||||
ring_next, ring_prev);
|
ring_next, ring_prev);
|
||||||
if (succ != kSuccess) {
|
if (succ != kSuccess) {
|
||||||
@ -849,7 +924,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
|||||||
}
|
}
|
||||||
chkpt.resize(rptr.back());
|
chkpt.resize(rptr.back());
|
||||||
// pass data through the link
|
// pass data through the link
|
||||||
succ = RingPassing(&chkpt[0], rptr[nlocal], rptr[nread_end],
|
succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end],
|
||||||
rptr[nwrite_start], rptr[nwrite_end],
|
rptr[nwrite_start], rptr[nwrite_end],
|
||||||
ring_prev, ring_next);
|
ring_prev, ring_next);
|
||||||
if (succ != kSuccess) {
|
if (succ != kSuccess) {
|
||||||
@ -858,6 +933,57 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
|||||||
}
|
}
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
/*!
|
||||||
|
* \brief try to checkpoint local state, this function is called in normal executation phase
|
||||||
|
* of checkpoint that contains local state
|
||||||
|
* the input state must exactly one saved state(local state of current node),
|
||||||
|
* after complete, this function will get local state from previous num_local_replica nodes and put them
|
||||||
|
* into local_chkpt and local_rptr
|
||||||
|
*
|
||||||
|
* It is also OK to call TryRecoverLocalState instead,
|
||||||
|
* TryRecoverLocalState makes less assumption about the input, and requires more communications
|
||||||
|
*
|
||||||
|
* \param p_local_rptr the pointer to the segment pointers in the states array
|
||||||
|
* \param p_local_chkpt the pointer to the storage of local check points
|
||||||
|
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||||
|
* \sa ReturnType, TryRecoverLocalState
|
||||||
|
*/
|
||||||
|
AllreduceRobust::ReturnType
|
||||||
|
AllreduceRobust::TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
|
||||||
|
std::string *p_local_chkpt) {
|
||||||
|
// if there is no local replica, we can do nothing
|
||||||
|
if (num_local_replica == 0) return kSuccess;
|
||||||
|
std::vector<size_t> &rptr = *p_local_rptr;
|
||||||
|
std::string &chkpt = *p_local_chkpt;
|
||||||
|
utils::Assert(rptr.size() == 2, "TryCheckinLocalState must have exactly 1 state");
|
||||||
|
const int n = num_local_replica;
|
||||||
|
std::vector<size_t> sizes(n + 1);
|
||||||
|
sizes[0] = rptr[1] - rptr[0];
|
||||||
|
ReturnType succ;
|
||||||
|
// pass size through the link
|
||||||
|
succ = RingPassing(BeginPtr(sizes),
|
||||||
|
1 * sizeof(size_t),
|
||||||
|
(n + 1) * sizeof(size_t),
|
||||||
|
0 * sizeof(size_t),
|
||||||
|
n * sizeof(size_t),
|
||||||
|
ring_prev, ring_next);
|
||||||
|
if (succ != kSuccess) return succ;
|
||||||
|
// update rptr
|
||||||
|
rptr.resize(n + 1);
|
||||||
|
for (int i = 1; i < n; ++i) {
|
||||||
|
rptr[i + 1] = rptr[i] + sizes[i];
|
||||||
|
}
|
||||||
|
chkpt.resize(rptr.back());
|
||||||
|
// pass data through the link
|
||||||
|
succ = RingPassing(BeginPtr(chkpt),
|
||||||
|
rptr[1], rptr[n + 1],
|
||||||
|
rptr[0], rptr[n],
|
||||||
|
ring_prev, ring_next);
|
||||||
|
if (succ != kSuccess) {
|
||||||
|
rptr.resize(2); chkpt.resize(rptr.back()); return succ;
|
||||||
|
}
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief perform a ring passing to receive data from prev link, and sent data to next link
|
* \brief perform a ring passing to receive data from prev link, and sent data to next link
|
||||||
* this allows data to stream over a ring structure
|
* this allows data to stream over a ring structure
|
||||||
@ -883,7 +1009,7 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
|
|||||||
size_t write_end,
|
size_t write_end,
|
||||||
LinkRecord *read_link,
|
LinkRecord *read_link,
|
||||||
LinkRecord *write_link) {
|
LinkRecord *write_link) {
|
||||||
if (links.size() == 0 || read_end == 0) return kSuccess;
|
if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess;
|
||||||
utils::Assert(read_end <= write_end, "boundary check");
|
utils::Assert(read_end <= write_end, "boundary check");
|
||||||
utils::Assert(read_ptr <= read_end, "boundary check");
|
utils::Assert(read_ptr <= read_end, "boundary check");
|
||||||
utils::Assert(write_ptr <= write_end, "boundary check");
|
utils::Assert(write_ptr <= write_end, "boundary check");
|
||||||
|
|||||||
@ -372,6 +372,23 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
*/
|
*/
|
||||||
ReturnType TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
ReturnType TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
||||||
std::string *p_local_chkpt);
|
std::string *p_local_chkpt);
|
||||||
|
/*!
|
||||||
|
* \brief try to checkpoint local state, this function is called in normal executation phase
|
||||||
|
* of checkpoint that contains local state
|
||||||
|
o * the input state must exactly one saved state(local state of current node),
|
||||||
|
* after complete, this function will get local state from previous num_local_replica nodes and put them
|
||||||
|
* into local_chkpt and local_rptr
|
||||||
|
*
|
||||||
|
* It is also OK to call TryRecoverLocalState instead,
|
||||||
|
* TryRecoverLocalState makes less assumption about the input, and requires more communications
|
||||||
|
*
|
||||||
|
* \param p_local_rptr the pointer to the segment pointers in the states array
|
||||||
|
* \param p_local_chkpt the pointer to the storage of local check points
|
||||||
|
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||||
|
* \sa ReturnType, TryRecoverLocalState
|
||||||
|
*/
|
||||||
|
ReturnType TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
|
||||||
|
std::string *p_local_chkpt);
|
||||||
/*!
|
/*!
|
||||||
* \brief perform a ring passing to receive data from prev link, and sent data to next link
|
* \brief perform a ring passing to receive data from prev link, and sent data to next link
|
||||||
* this allows data to stream over a ring structure
|
* this allows data to stream over a ring structure
|
||||||
@ -441,7 +458,7 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring
|
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring
|
||||||
std::vector<size_t> local_rptr[2];
|
std::vector<size_t> local_rptr[2];
|
||||||
// storage for local model replicas
|
// storage for local model replicas
|
||||||
std::string local_checkpoint[2];
|
std::string local_chkpt[2];
|
||||||
// version of local checkpoint can be 1 or 0
|
// version of local checkpoint can be 1 or 0
|
||||||
int local_chkpt_version;
|
int local_chkpt_version;
|
||||||
};
|
};
|
||||||
|
|||||||
@ -63,25 +63,32 @@ class SlaveEntry:
|
|||||||
return job_map[self.jobid]
|
return job_map[self.jobid]
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
def get_neighbor(self, rank, nslave):
|
def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map):
|
||||||
rank = rank + 1
|
|
||||||
ret = []
|
|
||||||
if rank > 1:
|
|
||||||
ret.append(rank / 2 - 1)
|
|
||||||
if rank * 2 - 1 < nslave:
|
|
||||||
ret.append(rank * 2 - 1)
|
|
||||||
if rank * 2 < nslave:
|
|
||||||
ret.append(rank * 2)
|
|
||||||
return set(ret)
|
|
||||||
|
|
||||||
def assign_rank(self, rank, wait_conn, nslave):
|
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
nnset = self.get_neighbor(rank, nslave)
|
nnset = set(tree_map[rank])
|
||||||
|
rprev, rnext = ring_map[rank]
|
||||||
self.sock.sendint(rank)
|
self.sock.sendint(rank)
|
||||||
# send parent rank
|
# send parent rank
|
||||||
self.sock.sendint((rank + 1) / 2 - 1)
|
self.sock.sendint(parent_map[rank])
|
||||||
# send world size
|
# send world size
|
||||||
self.sock.sendint(nslave)
|
self.sock.sendint(len(tree_map))
|
||||||
|
self.sock.sendint(len(nnset))
|
||||||
|
# send the rprev and next link
|
||||||
|
for r in nnset:
|
||||||
|
self.sock.sendint(r)
|
||||||
|
# send prev link
|
||||||
|
if rprev != -1 and rprev != rank:
|
||||||
|
nnset.add(rprev)
|
||||||
|
self.sock.sendint(rprev)
|
||||||
|
else:
|
||||||
|
self.sock.sendint(-1)
|
||||||
|
# send next link
|
||||||
|
if rnext != -1 and rnext != rank:
|
||||||
|
nnset.add(rnext)
|
||||||
|
self.sock.sendint(rnext)
|
||||||
|
else:
|
||||||
|
self.sock.sendint(-1)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
ngood = self.sock.recvint()
|
ngood = self.sock.recvint()
|
||||||
goodset = set([])
|
goodset = set([])
|
||||||
@ -131,8 +138,35 @@ class Tracker:
|
|||||||
self.sock.close()
|
self.sock.close()
|
||||||
def slave_args(self):
|
def slave_args(self):
|
||||||
return ['rabit_tracker_uri=%s' % socket.gethostname(),
|
return ['rabit_tracker_uri=%s' % socket.gethostname(),
|
||||||
'rabit_tracker_port=%s' % self.port]
|
'rabit_tracker_port=%s' % self.port]
|
||||||
|
def get_neighbor(self, rank, nslave):
|
||||||
|
rank = rank + 1
|
||||||
|
ret = []
|
||||||
|
if rank > 1:
|
||||||
|
ret.append(rank / 2 - 1)
|
||||||
|
if rank * 2 - 1 < nslave:
|
||||||
|
ret.append(rank * 2 - 1)
|
||||||
|
if rank * 2 < nslave:
|
||||||
|
ret.append(rank * 2)
|
||||||
|
return ret
|
||||||
|
def get_tree(self, nslave):
|
||||||
|
tree_map = {}
|
||||||
|
parent_map = {}
|
||||||
|
for r in range(nslave):
|
||||||
|
tree_map[r] = self.get_neighbor(r, nslave)
|
||||||
|
parent_map[r] = (r + 1) / 2 - 1
|
||||||
|
return tree_map, parent_map
|
||||||
|
def get_ring(self, tree_map, parent_map):
|
||||||
|
ring_map = {}
|
||||||
|
nslave = len(tree_map)
|
||||||
|
for r in range(nslave):
|
||||||
|
rprev = (r + nslave - 1) % nslave
|
||||||
|
rnext = (r + 1) % nslave
|
||||||
|
ring_map[r] = (rprev, rnext)
|
||||||
|
return ring_map
|
||||||
def accept_slaves(self, nslave):
|
def accept_slaves(self, nslave):
|
||||||
|
tree_map, parent_map = self.get_tree(nslave)
|
||||||
|
ring_map = self.get_ring(tree_map, parent_map)
|
||||||
# set of nodes that finishs the job
|
# set of nodes that finishs the job
|
||||||
shutdown = {}
|
shutdown = {}
|
||||||
# set of nodes that is waiting for connections
|
# set of nodes that is waiting for connections
|
||||||
@ -163,7 +197,7 @@ class Tracker:
|
|||||||
rank = todo_nodes.pop(0)
|
rank = todo_nodes.pop(0)
|
||||||
if s.jobid != 'NULL':
|
if s.jobid != 'NULL':
|
||||||
job_map[s.jobid] = rank
|
job_map[s.jobid] = rank
|
||||||
s.assign_rank(rank, wait_conn, nslave)
|
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
||||||
if s.wait_accept > 0:
|
if s.wait_accept > 0:
|
||||||
wait_conn[rank] = s
|
wait_conn[rank] = s
|
||||||
print 'All nodes finishes job'
|
print 'All nodes finishes job'
|
||||||
|
|||||||
@ -153,7 +153,8 @@ class Socket {
|
|||||||
* \param end_port ending port number to try
|
* \param end_port ending port number to try
|
||||||
* \return the port successfully bind to, return -1 if failed to bind any port
|
* \return the port successfully bind to, return -1 if failed to bind any port
|
||||||
*/
|
*/
|
||||||
inline int TryBindHost(int start_port, int end_port) {
|
inline int TryBindHost(int start_port, int end_port) {
|
||||||
|
// TODO, add prefix check
|
||||||
for (int port = start_port; port < end_port; ++port) {
|
for (int port = start_port; port < end_port; ++port) {
|
||||||
SockAddr addr("0.0.0.0", port);
|
SockAddr addr("0.0.0.0", port);
|
||||||
if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0) {
|
if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0) {
|
||||||
|
|||||||
@ -187,5 +187,13 @@ inline const T *BeginPtr(const std::vector<T> &vec) {
|
|||||||
return &vec[0];
|
return &vec[0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
inline char* BeginPtr(std::string &str) {
|
||||||
|
if (str.length() == 0) return NULL;
|
||||||
|
return &str[0];
|
||||||
|
}
|
||||||
|
inline const char* BeginPtr(const std::string &str) {
|
||||||
|
if (str.length() == 0) return NULL;
|
||||||
|
return &str[0];
|
||||||
|
}
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
#endif // RABIT_UTILS_H_
|
#endif // RABIT_UTILS_H_
|
||||||
|
|||||||
@ -24,7 +24,10 @@ def mpi_submit(nslave, args):
|
|||||||
args arguments to launch each job
|
args arguments to launch each job
|
||||||
this usually includes the parameters of master_uri and parameters passed into submit
|
this usually includes the parameters of master_uri and parameters passed into submit
|
||||||
"""
|
"""
|
||||||
cmd = ' '.join(['mpirun -n %d --hostfile %s' % (nslave, args[0])] + args[1:])
|
if args[0] == 'local':
|
||||||
|
cmd = ' '.join(['mpirun -n %d' % (nslave)] + args[1:])
|
||||||
|
else:
|
||||||
|
cmd = ' '.join(['mpirun -n %d --hostfile %s' % (nslave, args[0])] + args[1:])
|
||||||
print cmd
|
print cmd
|
||||||
subprocess.check_call(cmd, shell = True)
|
subprocess.check_call(cmd, shell = True)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user