support larger cluster (#73)
* fix error in dmlc#57, clean up comments and naming
* include missing packages, disable recovery tests for now
* disable local_recover tests until we have a bug fix
* support larger cluster
* fix lint, merge with master
parent 69cdfae22f
commit 3a35dabfae
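Note: the main change in this commit is swapping the select(2)-based SelectHelper for a poll(2)-based PollHelper. The removed helper asserted maxfd < FD_SETSIZE, so each worker could only watch a bounded number of peer sockets; poll has no such compile-time cap, which is what lets rabit scale to larger clusters. The following is a minimal illustrative sketch (POSIX assumed, not code from this commit):

// Illustration only: why poll() scales past select().
// select() can only watch descriptors below FD_SETSIZE (often 1024),
// while poll() takes an arbitrarily long array of pollfd entries.
#include <poll.h>
#include <sys/select.h>
#include <cstdio>
#include <vector>

int main() {
  std::printf("FD_SETSIZE = %d\n", FD_SETSIZE);  // hard cap for select()
  std::vector<pollfd> fds;                       // poll's watch set just grows with the cluster
  pollfd pfd{};
  pfd.fd = 0;                                    // watch stdin as a stand-in for a peer socket
  pfd.events = POLLIN;
  fds.push_back(pfd);
  int ready = poll(fds.data(), fds.size(), 0);   // 0 ms timeout: non-blocking probe
  std::printf("poll returned %d\n", ready);
  return 0;
}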
@@ -208,9 +208,9 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
 } else {
 fprintf(stderr, "retry connect to ip(retry time %d): [%s]\n", retry, tracker_uri.c_str());
 #ifdef _MSC_VER
-Sleep(1);
+Sleep(retry << 1);
 #else
-sleep(1);
+sleep(retry << 1);
 #endif
 continue;
 }
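The hunk above also stretches the tracker reconnect delay from a fixed one second to one that grows with the retry count (retry << 1), so many workers starting together do not hammer the tracker in lockstep. A small sketch of that backoff pattern, with a hypothetical connect_once() standing in for the real connect call (illustration only):

// Sketch of a growing retry delay when connecting to the tracker.
// connect_once() is a hypothetical stand-in for the real TCP connect.
#include <unistd.h>
#include <cstdio>

static bool connect_once(int attempt) {
  // Pretend the tracker only becomes reachable on the third attempt.
  return attempt >= 2;
}

int main() {
  const int max_retry = 5;
  for (int retry = 0; retry < max_retry; ++retry) {
    if (connect_once(retry)) {
      std::printf("connected on retry %d\n", retry);
      return 0;
    }
    std::printf("retry connect to tracker (retry time %d)\n", retry);
    sleep(retry << 1);  // wait 0, 2, 4, ... seconds: later retries back off longer
  }
  return 1;
}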
@@ -454,29 +454,29 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
 while (true) {
 // select helper
 bool finished = true;
-utils::SelectHelper selecter;
+utils::PollHelper watcher;
 for (int i = 0; i < nlink; ++i) {
 if (i == parent_index) {
 if (size_down_in != total_size) {
-selecter.WatchRead(links[i].sock);
+watcher.WatchRead(links[i].sock);
 // only watch for exception in live channels
-selecter.WatchException(links[i].sock);
+watcher.WatchException(links[i].sock);
 finished = false;
 }
 if (size_up_out != total_size && size_up_out < size_up_reduce) {
-selecter.WatchWrite(links[i].sock);
+watcher.WatchWrite(links[i].sock);
 }
 } else {
 if (links[i].size_read != total_size) {
-selecter.WatchRead(links[i].sock);
+watcher.WatchRead(links[i].sock);
 }
 // size_write <= size_read
 if (links[i].size_write != total_size) {
 if (links[i].size_write < size_down_in) {
-selecter.WatchWrite(links[i].sock);
+watcher.WatchWrite(links[i].sock);
 }
 // only watch for exception in live channels
-selecter.WatchException(links[i].sock);
+watcher.WatchException(links[i].sock);
 finished = false;
 }
 }
@@ -484,17 +484,17 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
 // finish runing allreduce
 if (finished) break;
 // select must return
-selecter.Select();
+watcher.Poll();
 // exception handling
 for (int i = 0; i < nlink; ++i) {
 // recive OOB message from some link
-if (selecter.CheckExcept(links[i].sock)) {
+if (watcher.CheckExcept(links[i].sock)) {
 return ReportError(&links[i], kGetExcept);
 }
 }
 // read data from childs
 for (int i = 0; i < nlink; ++i) {
-if (i != parent_index && selecter.CheckRead(links[i].sock)) {
+if (i != parent_index && watcher.CheckRead(links[i].sock)) {
 ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
 if (ret != kSuccess) {
 return ReportError(&links[i], ret);
@@ -551,7 +551,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
 }
 }
 // read data from parent
-if (selecter.CheckRead(links[parent_index].sock) &&
+if (watcher.CheckRead(links[parent_index].sock) &&
 total_size > size_down_in) {
 ssize_t len = links[parent_index].sock.
 Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
@@ -620,37 +620,37 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
 while (true) {
 bool finished = true;
 // select helper
-utils::SelectHelper selecter;
+utils::PollHelper watcher;
 for (int i = 0; i < nlink; ++i) {
 if (in_link == -2) {
-selecter.WatchRead(links[i].sock); finished = false;
+watcher.WatchRead(links[i].sock); finished = false;
 }
 if (i == in_link && links[i].size_read != total_size) {
-selecter.WatchRead(links[i].sock); finished = false;
+watcher.WatchRead(links[i].sock); finished = false;
 }
 if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
 if (links[i].size_write < size_in) {
-selecter.WatchWrite(links[i].sock);
+watcher.WatchWrite(links[i].sock);
 }
 finished = false;
 }
-selecter.WatchException(links[i].sock);
+watcher.WatchException(links[i].sock);
 }
 // finish running
 if (finished) break;
 // select
-selecter.Select();
+watcher.Poll();
 // exception handling
 for (int i = 0; i < nlink; ++i) {
 // recive OOB message from some link
-if (selecter.CheckExcept(links[i].sock)) {
+if (watcher.CheckExcept(links[i].sock)) {
 return ReportError(&links[i], kGetExcept);
 }
 }
 if (in_link == -2) {
 // probe in-link
 for (int i = 0; i < nlink; ++i) {
-if (selecter.CheckRead(links[i].sock)) {
+if (watcher.CheckRead(links[i].sock)) {
 ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size);
 if (ret != kSuccess) {
 return ReportError(&links[i], ret);
@@ -663,7 +663,7 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
 }
 } else {
 // read from in link
-if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) {
+if (in_link >= 0 && watcher.CheckRead(links[in_link].sock)) {
 ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size);
 if (ret != kSuccess) {
 return ReportError(&links[in_link], ret);
@@ -717,20 +717,20 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
 while (true) {
 // select helper
 bool finished = true;
-utils::SelectHelper selecter;
+utils::PollHelper watcher;
 if (read_ptr != stop_read) {
-selecter.WatchRead(next.sock);
+watcher.WatchRead(next.sock);
 finished = false;
 }
 if (write_ptr != stop_write) {
 if (write_ptr < read_ptr) {
-selecter.WatchWrite(prev.sock);
+watcher.WatchWrite(prev.sock);
 }
 finished = false;
 }
 if (finished) break;
-selecter.Select();
-if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
+watcher.Poll();
+if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
 size_t size = stop_read - read_ptr;
 size_t start = read_ptr % total_size;
 if (start + size > total_size) {
@@ -811,20 +811,20 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
 while (true) {
 // select helper
 bool finished = true;
-utils::SelectHelper selecter;
+utils::PollHelper watcher;
 if (read_ptr != stop_read) {
-selecter.WatchRead(next.sock);
+watcher.WatchRead(next.sock);
 finished = false;
 }
 if (write_ptr != stop_write) {
 if (write_ptr < reduce_ptr) {
-selecter.WatchWrite(prev.sock);
+watcher.WatchWrite(prev.sock);
 }
 finished = false;
 }
 if (finished) break;
-selecter.Select();
-if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
+watcher.Poll();
+if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
 ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
 if (ret != kSuccess) {
 return ReportError(&next, ret);
@@ -69,30 +69,30 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
 if (parent_index == -1) {
 utils::Assert(stage != 2 && stage != 1, "invalie stage id");
 }
-// select helper
-utils::SelectHelper selecter;
+// poll helper
+utils::PollHelper watcher;
 bool done = (stage == 3);
 for (int i = 0; i < nlink; ++i) {
-selecter.WatchException(links[i].sock);
+watcher.WatchException(links[i].sock);
 switch (stage) {
 case 0:
 if (i != parent_index && links[i].size_read != sizeof(EdgeType)) {
-selecter.WatchRead(links[i].sock);
+watcher.WatchRead(links[i].sock);
 }
 break;
 case 1:
 if (i == parent_index) {
-selecter.WatchWrite(links[i].sock);
+watcher.WatchWrite(links[i].sock);
 }
 break;
 case 2:
 if (i == parent_index) {
-selecter.WatchRead(links[i].sock);
+watcher.WatchRead(links[i].sock);
 }
 break;
 case 3:
 if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
-selecter.WatchWrite(links[i].sock);
+watcher.WatchWrite(links[i].sock);
 done = false;
 }
 break;
@@ -101,11 +101,11 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
 }
 // finish all the stages, and write out message
 if (done) break;
-selecter.Select();
+watcher.Poll();
 // exception handling
 for (int i = 0; i < nlink; ++i) {
 // recive OOB message from some link
-if (selecter.CheckExcept(links[i].sock)) {
+if (watcher.CheckExcept(links[i].sock)) {
 return ReportError(&links[i], kGetExcept);
 }
 }
@@ -114,7 +114,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
 // read data from childs
 for (int i = 0; i < nlink; ++i) {
 if (i != parent_index) {
-if (selecter.CheckRead(links[i].sock)) {
+if (watcher.CheckRead(links[i].sock)) {
 ReturnType ret = links[i].ReadToArray(&edge_in[i], sizeof(EdgeType));
 if (ret != kSuccess) return ReportError(&links[i], ret);
 }
@@ -334,7 +334,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
 if (len == sizeof(sig)) all_links[i].size_write = 2;
 }
 }
-utils::SelectHelper rsel;
+utils::PollHelper rsel;
 bool finished = true;
 for (int i = 0; i < nlink; ++i) {
 if (all_links[i].size_write != 2 && !all_links[i].sock.BadSocket()) {
@@ -343,15 +343,15 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
 }
 if (finished) break;
 // wait to read from the channels to discard data
-rsel.Select();
+rsel.Poll();
 }
 for (int i = 0; i < nlink; ++i) {
 if (!all_links[i].sock.BadSocket()) {
-utils::SelectHelper::WaitExcept(all_links[i].sock);
+utils::PollHelper::WaitExcept(all_links[i].sock);
 }
 }
 while (true) {
-utils::SelectHelper rsel;
+utils::PollHelper rsel;
 bool finished = true;
 for (int i = 0; i < nlink; ++i) {
 if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) {
@@ -359,7 +359,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
 }
 }
 if (finished) break;
-rsel.Select();
+rsel.Poll();
 for (int i = 0; i < nlink; ++i) {
 if (all_links[i].sock.BadSocket()) continue;
 if (all_links[i].size_read == 0) {
@@ -624,32 +624,32 @@ AllreduceRobust::TryRecoverData(RecoverType role,
 }
 while (true) {
 bool finished = true;
-utils::SelectHelper selecter;
+utils::PollHelper watcher;
 for (int i = 0; i < nlink; ++i) {
 if (i == recv_link && links[i].size_read != size) {
-selecter.WatchRead(links[i].sock);
+watcher.WatchRead(links[i].sock);
 finished = false;
 }
 if (req_in[i] && links[i].size_write != size) {
 if (role == kHaveData ||
 (links[recv_link].size_read != links[i].size_write)) {
-selecter.WatchWrite(links[i].sock);
+watcher.WatchWrite(links[i].sock);
 }
 finished = false;
 }
-selecter.WatchException(links[i].sock);
+watcher.WatchException(links[i].sock);
 }
 if (finished) break;
-selecter.Select();
+watcher.Poll();
 // exception handling
 for (int i = 0; i < nlink; ++i) {
-if (selecter.CheckExcept(links[i].sock)) {
+if (watcher.CheckExcept(links[i].sock)) {
 return ReportError(&links[i], kGetExcept);
 }
 }
 if (role == kRequestData) {
 const int pid = recv_link;
-if (selecter.CheckRead(links[pid].sock)) {
+if (watcher.CheckRead(links[pid].sock)) {
 ReturnType ret = links[pid].ReadToArray(sendrecvbuf_, size);
 if (ret != kSuccess) {
 return ReportError(&links[pid], ret);
@@ -677,7 +677,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
 if (role == kPassData) {
 const int pid = recv_link;
 const size_t buffer_size = links[pid].buffer_size;
-if (selecter.CheckRead(links[pid].sock)) {
+if (watcher.CheckRead(links[pid].sock)) {
 size_t min_write = size;
 for (int i = 0; i < nlink; ++i) {
 if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
@@ -1144,22 +1144,22 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
 char *buf = reinterpret_cast<char*>(sendrecvbuf_);
 while (true) {
 bool finished = true;
-utils::SelectHelper selecter;
+utils::PollHelper watcher;
 if (read_ptr != read_end) {
-selecter.WatchRead(prev.sock);
+watcher.WatchRead(prev.sock);
 finished = false;
 }
 if (write_ptr < read_ptr && write_ptr != write_end) {
-selecter.WatchWrite(next.sock);
+watcher.WatchWrite(next.sock);
 finished = false;
 }
-selecter.WatchException(prev.sock);
-selecter.WatchException(next.sock);
+watcher.WatchException(prev.sock);
+watcher.WatchException(next.sock);
 if (finished) break;
-selecter.Select();
-if (selecter.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept);
-if (selecter.CheckExcept(next.sock)) return ReportError(&next, kGetExcept);
-if (read_ptr != read_end && selecter.CheckRead(prev.sock)) {
+watcher.Poll();
+if (watcher.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept);
+if (watcher.CheckExcept(next.sock)) return ReportError(&next, kGetExcept);
+if (read_ptr != read_end && watcher.CheckRead(prev.sock)) {
 ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr);
 if (len == 0) {
 prev.sock.Close(); return ReportError(&prev, kRecvZeroLen);
src/socket.h
@@ -20,17 +20,22 @@
 #include <arpa/inet.h>
 #include <netinet/in.h>
 #include <sys/socket.h>
-#include <sys/select.h>
 #include <sys/ioctl.h>
 #endif
 #include <string>
 #include <cstring>
+#include <vector>
+#include <unordered_map>
 #include "../include/rabit/internal/utils.h"

 #if defined(_WIN32)
 typedef int ssize_t;
 typedef int sock_size_t;
+
+static inline int poll(struct pollfd *pfd, int nfds,
+int timeout) { return WSAPoll ( pfd, nfds, timeout ); }
 #else
+#include <sys/poll.h>
 typedef int SOCKET;
 typedef size_t sock_size_t;
 const int INVALID_SOCKET = -1;
@@ -209,7 +214,8 @@ class Socket {
 inline int GetSockError(void) const {
 int error = 0;
 socklen_t len = sizeof(error);
-if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0) {
+if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
+reinterpret_cast<char*>(&error), &len) != 0) {
 Error("GetSockError");
 }
 return error;
@@ -419,109 +425,100 @@ class TCPSocket : public Socket{
 }
 };

-/*! \brief helper data structure to perform select */
-struct SelectHelper {
+/*! \brief helper data structure to perform poll */
+struct PollHelper {
 public:
-SelectHelper(void) {
-FD_ZERO(&read_set);
-FD_ZERO(&write_set);
-FD_ZERO(&except_set);
-maxfd = 0;
-}
 /*!
 * \brief add file descriptor to watch for read
 * \param fd file descriptor to be watched
 */
 inline void WatchRead(SOCKET fd) {
-FD_SET(fd, &read_set);
-if (fd > maxfd) maxfd = fd;
+auto& pfd = fds[fd];
+pfd.fd = fd;
+pfd.events |= POLLIN;
 }
 /*!
 * \brief add file descriptor to watch for write
 * \param fd file descriptor to be watched
 */
 inline void WatchWrite(SOCKET fd) {
-FD_SET(fd, &write_set);
-if (fd > maxfd) maxfd = fd;
+auto& pfd = fds[fd];
+pfd.fd = fd;
+pfd.events |= POLLOUT;
 }
 /*!
 * \brief add file descriptor to watch for exception
 * \param fd file descriptor to be watched
 */
 inline void WatchException(SOCKET fd) {
-FD_SET(fd, &except_set);
-if (fd > maxfd) maxfd = fd;
+auto& pfd = fds[fd];
+pfd.fd = fd;
+pfd.events |= POLLPRI;
 }
 /*!
 * \brief Check if the descriptor is ready for read
 * \param fd file descriptor to check status
 */
 inline bool CheckRead(SOCKET fd) const {
-return FD_ISSET(fd, &read_set) != 0;
+const auto& pfd = fds.find(fd);
+return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
 }
 /*!
 * \brief Check if the descriptor is ready for write
 * \param fd file descriptor to check status
 */
 inline bool CheckWrite(SOCKET fd) const {
-return FD_ISSET(fd, &write_set) != 0;
+const auto& pfd = fds.find(fd);
+return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
 }
 /*!
 * \brief Check if the descriptor has any exception
 * \param fd file descriptor to check status
 */
 inline bool CheckExcept(SOCKET fd) const {
-return FD_ISSET(fd, &except_set) != 0;
+const auto& pfd = fds.find(fd);
+return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0);
 }
 /*!
 * \brief wait for exception event on a single descriptor
 * \param fd the file descriptor to wait the event for
-* \param timeout the timeout counter, can be 0, which means wait until the event happen
+* \param timeout the timeout counter, can be negative, which means wait until the event happen
 * \return 1 if success, 0 if timeout, and -1 if error occurs
 */
-inline static int WaitExcept(SOCKET fd, long timeout = 0) { // NOLINT(*)
-fd_set wait_set;
-FD_ZERO(&wait_set);
-FD_SET(fd, &wait_set);
-return Select_(static_cast<int>(fd + 1),
-NULL, NULL, &wait_set, timeout);
+inline static int WaitExcept(SOCKET fd, long timeout = -1) { // NOLINT(*)
+pollfd pfd;
+pfd.fd = fd;
+pfd.events = POLLPRI;
+return poll(&pfd, 1, timeout);
 }

 /*!
-* \brief peform select on the set defined
-* \param select_read whether to watch for read event
-* \param select_write whether to watch for write event
-* \param select_except whether to watch for exception event
-* \param timeout specify timeout in micro-seconds(ms) if equals 0, means select will always block
-* \return number of active descriptors selected,
-* return -1 if error occurs
+* \brief peform poll on the set defined, read, write, exception
+* \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
+* \return
 */
-inline int Select(long timeout = 0) { // NOLINT(*)
-int ret = Select_(static_cast<int>(maxfd + 1),
-&read_set, &write_set, &except_set, timeout);
+inline void Poll(long timeout = -1) { // NOLINT(*)
+std::vector<pollfd> fdset;
+fdset.reserve(fds.size());
+for (auto kv : fds) {
+fdset.push_back(kv.second);
+}
+int ret = poll(fdset.data(), fdset.size(), timeout);
 if (ret == -1) {
-Socket::Error("Select");
-}
-return ret;
-}
-
-private:
-inline static int Select_(int maxfd, fd_set *rfds,
-fd_set *wfds, fd_set *efds, long timeout) { // NOLINT(*)
-#if !defined(_WIN32)
-utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE");
-#endif
-if (timeout == 0) {
-return select(maxfd, rfds, wfds, efds, NULL);
+Socket::Error("Poll");
 } else {
-timeval tm;
-tm.tv_usec = (timeout % 1000) * 1000;
-tm.tv_sec = timeout / 1000;
-return select(maxfd, rfds, wfds, efds, &tm);
+for (auto& pfd : fdset) {
+auto revents = pfd.revents & pfd.events;
+if (!revents) {
+fds.erase(pfd.fd);
+} else {
+fds[pfd.fd].events = revents;
+}
+}
 }
 }

-SOCKET maxfd;
-fd_set read_set, write_set, except_set;
+std::unordered_map<SOCKET, pollfd> fds;
 };
 } // namespace utils
 } // namespace rabit
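The allreduce loops above keep the same watch/poll/check calling pattern after the rename; only the helper type behind it changes. Below is a self-contained sketch of that pattern against plain poll(2), using a simplified MiniPollHelper rather than the real rabit::utils::PollHelper (POSIX assumed, illustration only):

// Sketch of the watch -> Poll -> check pattern used by the allreduce loops,
// written against plain poll(2) instead of the real rabit::utils::PollHelper.
#include <poll.h>
#include <unistd.h>
#include <cstdio>
#include <unordered_map>
#include <vector>

struct MiniPollHelper {               // simplified stand-in, mirrors the diff above
  void WatchRead(int fd)  { fds[fd].fd = fd; fds[fd].events |= POLLIN; }
  void WatchWrite(int fd) { fds[fd].fd = fd; fds[fd].events |= POLLOUT; }
  void Poll(int timeout_ms = -1) {
    std::vector<pollfd> fdset;        // copy the watch map into a contiguous array
    for (auto &kv : fds) fdset.push_back(kv.second);
    if (poll(fdset.data(), fdset.size(), timeout_ms) == -1) std::perror("poll");
    for (auto &pfd : fdset)           // keep only the events that actually fired
      fds[pfd.fd].events = pfd.revents & pfd.events;
  }
  bool CheckRead(int fd) const {
    auto it = fds.find(fd);
    return it != fds.end() && (it->second.events & POLLIN);
  }
  std::unordered_map<int, pollfd> fds;
};

int main() {
  int pipefd[2];
  if (pipe(pipefd) != 0) return 1;
  (void)write(pipefd[1], "x", 1);     // make the read end ready
  MiniPollHelper watcher;
  watcher.WatchRead(pipefd[0]);       // register interest, like the loops above
  watcher.Poll(100);                  // wait up to 100 ms
  std::printf("readable: %d\n", watcher.CheckRead(pipefd[0]) ? 1 : 0);
  close(pipefd[0]); close(pipefd[1]);
  return 0;
}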
@@ -4,6 +4,7 @@
 all: model_recover_10_10k model_recover_10_10k_die_same model_recover_10_10k_die_hard local_recover_10_10k

 # this experiment test recovery with actually process exit, use keepalive to keep program alive
+# TODO: enable those tests once we fix issue in rabit
 model_recover_10_10k:
 ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0
