support larger cluster (#73)

* fix error in dmlc#57, clean up comments and naming

* include missing packages, disable recovery tests for now

* disable local_recover tests until we have a bug fix

* support larger cluster

* fix lint, merge with master
This commit is contained in:
Chen Qin 2018-10-22 10:13:45 -07:00 committed by Nan Zhu
parent 69cdfae22f
commit 3a35dabfae
5 changed files with 121 additions and 123 deletions

View File

@ -208,9 +208,9 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
} else { } else {
fprintf(stderr, "retry connect to ip(retry time %d): [%s]\n", retry, tracker_uri.c_str()); fprintf(stderr, "retry connect to ip(retry time %d): [%s]\n", retry, tracker_uri.c_str());
#ifdef _MSC_VER #ifdef _MSC_VER
Sleep(1); Sleep(retry << 1);
#else #else
sleep(1); sleep(retry << 1);
#endif #endif
continue; continue;
} }
@ -454,29 +454,29 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
while (true) { while (true) {
// select helper // select helper
bool finished = true; bool finished = true;
utils::SelectHelper selecter; utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (i == parent_index) { if (i == parent_index) {
if (size_down_in != total_size) { if (size_down_in != total_size) {
selecter.WatchRead(links[i].sock); watcher.WatchRead(links[i].sock);
// only watch for exception in live channels // only watch for exception in live channels
selecter.WatchException(links[i].sock); watcher.WatchException(links[i].sock);
finished = false; finished = false;
} }
if (size_up_out != total_size && size_up_out < size_up_reduce) { if (size_up_out != total_size && size_up_out < size_up_reduce) {
selecter.WatchWrite(links[i].sock); watcher.WatchWrite(links[i].sock);
} }
} else { } else {
if (links[i].size_read != total_size) { if (links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock); watcher.WatchRead(links[i].sock);
} }
// size_write <= size_read // size_write <= size_read
if (links[i].size_write != total_size) { if (links[i].size_write != total_size) {
if (links[i].size_write < size_down_in) { if (links[i].size_write < size_down_in) {
selecter.WatchWrite(links[i].sock); watcher.WatchWrite(links[i].sock);
} }
// only watch for exception in live channels // only watch for exception in live channels
selecter.WatchException(links[i].sock); watcher.WatchException(links[i].sock);
finished = false; finished = false;
} }
} }
@ -484,17 +484,17 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
// finish running allreduce // finish running allreduce
if (finished) break; if (finished) break;
// select must return // select must return
selecter.Select(); watcher.Poll();
// exception handling // exception handling
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link // receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) { if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept); return ReportError(&links[i], kGetExcept);
} }
} }
// read data from children // read data from children
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (i != parent_index && selecter.CheckRead(links[i].sock)) { if (i != parent_index && watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size); ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
if (ret != kSuccess) { if (ret != kSuccess) {
return ReportError(&links[i], ret); return ReportError(&links[i], ret);
@ -551,7 +551,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
} }
} }
// read data from parent // read data from parent
if (selecter.CheckRead(links[parent_index].sock) && if (watcher.CheckRead(links[parent_index].sock) &&
total_size > size_down_in) { total_size > size_down_in) {
ssize_t len = links[parent_index].sock. ssize_t len = links[parent_index].sock.
Recv(sendrecvbuf + size_down_in, total_size - size_down_in); Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
@ -620,37 +620,37 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
while (true) { while (true) {
bool finished = true; bool finished = true;
// select helper // select helper
utils::SelectHelper selecter; utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (in_link == -2) { if (in_link == -2) {
selecter.WatchRead(links[i].sock); finished = false; watcher.WatchRead(links[i].sock); finished = false;
} }
if (i == in_link && links[i].size_read != total_size) { if (i == in_link && links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock); finished = false; watcher.WatchRead(links[i].sock); finished = false;
} }
if (in_link != -2 && i != in_link && links[i].size_write != total_size) { if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
if (links[i].size_write < size_in) { if (links[i].size_write < size_in) {
selecter.WatchWrite(links[i].sock); watcher.WatchWrite(links[i].sock);
} }
finished = false; finished = false;
} }
selecter.WatchException(links[i].sock); watcher.WatchException(links[i].sock);
} }
// finish running // finish running
if (finished) break; if (finished) break;
// select // select
selecter.Select(); watcher.Poll();
// exception handling // exception handling
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link // receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) { if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept); return ReportError(&links[i], kGetExcept);
} }
} }
if (in_link == -2) { if (in_link == -2) {
// probe in-link // probe in-link
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (selecter.CheckRead(links[i].sock)) { if (watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size); ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) { if (ret != kSuccess) {
return ReportError(&links[i], ret); return ReportError(&links[i], ret);
@ -663,7 +663,7 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
} }
} else { } else {
// read from in link // read from in link
if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) { if (in_link >= 0 && watcher.CheckRead(links[in_link].sock)) {
ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size); ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) { if (ret != kSuccess) {
return ReportError(&links[in_link], ret); return ReportError(&links[in_link], ret);
@ -717,20 +717,20 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
while (true) { while (true) {
// select helper // select helper
bool finished = true; bool finished = true;
utils::SelectHelper selecter; utils::PollHelper watcher;
if (read_ptr != stop_read) { if (read_ptr != stop_read) {
selecter.WatchRead(next.sock); watcher.WatchRead(next.sock);
finished = false; finished = false;
} }
if (write_ptr != stop_write) { if (write_ptr != stop_write) {
if (write_ptr < read_ptr) { if (write_ptr < read_ptr) {
selecter.WatchWrite(prev.sock); watcher.WatchWrite(prev.sock);
} }
finished = false; finished = false;
} }
if (finished) break; if (finished) break;
selecter.Select(); watcher.Poll();
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) { if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
size_t size = stop_read - read_ptr; size_t size = stop_read - read_ptr;
size_t start = read_ptr % total_size; size_t start = read_ptr % total_size;
if (start + size > total_size) { if (start + size > total_size) {
@ -811,20 +811,20 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
while (true) { while (true) {
// select helper // select helper
bool finished = true; bool finished = true;
utils::SelectHelper selecter; utils::PollHelper watcher;
if (read_ptr != stop_read) { if (read_ptr != stop_read) {
selecter.WatchRead(next.sock); watcher.WatchRead(next.sock);
finished = false; finished = false;
} }
if (write_ptr != stop_write) { if (write_ptr != stop_write) {
if (write_ptr < reduce_ptr) { if (write_ptr < reduce_ptr) {
selecter.WatchWrite(prev.sock); watcher.WatchWrite(prev.sock);
} }
finished = false; finished = false;
} }
if (finished) break; if (finished) break;
selecter.Select(); watcher.Poll();
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) { if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read); ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
if (ret != kSuccess) { if (ret != kSuccess) {
return ReportError(&next, ret); return ReportError(&next, ret);

View File

@ -69,30 +69,30 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
if (parent_index == -1) { if (parent_index == -1) {
utils::Assert(stage != 2 && stage != 1, "invalie stage id"); utils::Assert(stage != 2 && stage != 1, "invalie stage id");
} }
// select helper // poll helper
utils::SelectHelper selecter; utils::PollHelper watcher;
bool done = (stage == 3); bool done = (stage == 3);
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
selecter.WatchException(links[i].sock); watcher.WatchException(links[i].sock);
switch (stage) { switch (stage) {
case 0: case 0:
if (i != parent_index && links[i].size_read != sizeof(EdgeType)) { if (i != parent_index && links[i].size_read != sizeof(EdgeType)) {
selecter.WatchRead(links[i].sock); watcher.WatchRead(links[i].sock);
} }
break; break;
case 1: case 1:
if (i == parent_index) { if (i == parent_index) {
selecter.WatchWrite(links[i].sock); watcher.WatchWrite(links[i].sock);
} }
break; break;
case 2: case 2:
if (i == parent_index) { if (i == parent_index) {
selecter.WatchRead(links[i].sock); watcher.WatchRead(links[i].sock);
} }
break; break;
case 3: case 3:
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) { if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
selecter.WatchWrite(links[i].sock); watcher.WatchWrite(links[i].sock);
done = false; done = false;
} }
break; break;
@ -101,11 +101,11 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
} }
// finish all the stages, and write out message // finish all the stages, and write out message
if (done) break; if (done) break;
selecter.Select(); watcher.Poll();
// exception handling // exception handling
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link // receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) { if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept); return ReportError(&links[i], kGetExcept);
} }
} }
@ -114,7 +114,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
// read data from children // read data from children
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (i != parent_index) { if (i != parent_index) {
if (selecter.CheckRead(links[i].sock)) { if (watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToArray(&edge_in[i], sizeof(EdgeType)); ReturnType ret = links[i].ReadToArray(&edge_in[i], sizeof(EdgeType));
if (ret != kSuccess) return ReportError(&links[i], ret); if (ret != kSuccess) return ReportError(&links[i], ret);
} }

View File

@ -334,7 +334,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
if (len == sizeof(sig)) all_links[i].size_write = 2; if (len == sizeof(sig)) all_links[i].size_write = 2;
} }
} }
utils::SelectHelper rsel; utils::PollHelper rsel;
bool finished = true; bool finished = true;
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (all_links[i].size_write != 2 && !all_links[i].sock.BadSocket()) { if (all_links[i].size_write != 2 && !all_links[i].sock.BadSocket()) {
@ -343,15 +343,15 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
} }
if (finished) break; if (finished) break;
// wait to read from the channels to discard data // wait to read from the channels to discard data
rsel.Select(); rsel.Poll();
} }
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (!all_links[i].sock.BadSocket()) { if (!all_links[i].sock.BadSocket()) {
utils::SelectHelper::WaitExcept(all_links[i].sock); utils::PollHelper::WaitExcept(all_links[i].sock);
} }
} }
while (true) { while (true) {
utils::SelectHelper rsel; utils::PollHelper rsel;
bool finished = true; bool finished = true;
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) { if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) {
@ -359,7 +359,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
} }
} }
if (finished) break; if (finished) break;
rsel.Select(); rsel.Poll();
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (all_links[i].sock.BadSocket()) continue; if (all_links[i].sock.BadSocket()) continue;
if (all_links[i].size_read == 0) { if (all_links[i].size_read == 0) {
@ -624,32 +624,32 @@ AllreduceRobust::TryRecoverData(RecoverType role,
} }
while (true) { while (true) {
bool finished = true; bool finished = true;
utils::SelectHelper selecter; utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (i == recv_link && links[i].size_read != size) { if (i == recv_link && links[i].size_read != size) {
selecter.WatchRead(links[i].sock); watcher.WatchRead(links[i].sock);
finished = false; finished = false;
} }
if (req_in[i] && links[i].size_write != size) { if (req_in[i] && links[i].size_write != size) {
if (role == kHaveData || if (role == kHaveData ||
(links[recv_link].size_read != links[i].size_write)) { (links[recv_link].size_read != links[i].size_write)) {
selecter.WatchWrite(links[i].sock); watcher.WatchWrite(links[i].sock);
} }
finished = false; finished = false;
} }
selecter.WatchException(links[i].sock); watcher.WatchException(links[i].sock);
} }
if (finished) break; if (finished) break;
selecter.Select(); watcher.Poll();
// exception handling // exception handling
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (selecter.CheckExcept(links[i].sock)) { if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept); return ReportError(&links[i], kGetExcept);
} }
} }
if (role == kRequestData) { if (role == kRequestData) {
const int pid = recv_link; const int pid = recv_link;
if (selecter.CheckRead(links[pid].sock)) { if (watcher.CheckRead(links[pid].sock)) {
ReturnType ret = links[pid].ReadToArray(sendrecvbuf_, size); ReturnType ret = links[pid].ReadToArray(sendrecvbuf_, size);
if (ret != kSuccess) { if (ret != kSuccess) {
return ReportError(&links[pid], ret); return ReportError(&links[pid], ret);
@ -677,7 +677,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
if (role == kPassData) { if (role == kPassData) {
const int pid = recv_link; const int pid = recv_link;
const size_t buffer_size = links[pid].buffer_size; const size_t buffer_size = links[pid].buffer_size;
if (selecter.CheckRead(links[pid].sock)) { if (watcher.CheckRead(links[pid].sock)) {
size_t min_write = size; size_t min_write = size;
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (req_in[i]) min_write = std::min(links[i].size_write, min_write); if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
@ -1144,22 +1144,22 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
char *buf = reinterpret_cast<char*>(sendrecvbuf_); char *buf = reinterpret_cast<char*>(sendrecvbuf_);
while (true) { while (true) {
bool finished = true; bool finished = true;
utils::SelectHelper selecter; utils::PollHelper watcher;
if (read_ptr != read_end) { if (read_ptr != read_end) {
selecter.WatchRead(prev.sock); watcher.WatchRead(prev.sock);
finished = false; finished = false;
} }
if (write_ptr < read_ptr && write_ptr != write_end) { if (write_ptr < read_ptr && write_ptr != write_end) {
selecter.WatchWrite(next.sock); watcher.WatchWrite(next.sock);
finished = false; finished = false;
} }
selecter.WatchException(prev.sock); watcher.WatchException(prev.sock);
selecter.WatchException(next.sock); watcher.WatchException(next.sock);
if (finished) break; if (finished) break;
selecter.Select(); watcher.Poll();
if (selecter.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept); if (watcher.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept);
if (selecter.CheckExcept(next.sock)) return ReportError(&next, kGetExcept); if (watcher.CheckExcept(next.sock)) return ReportError(&next, kGetExcept);
if (read_ptr != read_end && selecter.CheckRead(prev.sock)) { if (read_ptr != read_end && watcher.CheckRead(prev.sock)) {
ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr); ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr);
if (len == 0) { if (len == 0) {
prev.sock.Close(); return ReportError(&prev, kRecvZeroLen); prev.sock.Close(); return ReportError(&prev, kRecvZeroLen);

View File

@ -20,17 +20,22 @@
#include <arpa/inet.h> #include <arpa/inet.h>
#include <netinet/in.h> #include <netinet/in.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/select.h>
#include <sys/ioctl.h> #include <sys/ioctl.h>
#endif #endif
#include <string> #include <string>
#include <cstring> #include <cstring>
#include <vector>
#include <unordered_map>
#include "../include/rabit/internal/utils.h" #include "../include/rabit/internal/utils.h"
#if defined(_WIN32) #if defined(_WIN32)
typedef int ssize_t; typedef int ssize_t;
typedef int sock_size_t; typedef int sock_size_t;
static inline int poll(struct pollfd *pfd, int nfds,
int timeout) { return WSAPoll ( pfd, nfds, timeout ); }
#else #else
#include <sys/poll.h>
typedef int SOCKET; typedef int SOCKET;
typedef size_t sock_size_t; typedef size_t sock_size_t;
const int INVALID_SOCKET = -1; const int INVALID_SOCKET = -1;
@ -78,7 +83,7 @@ struct SockAddr {
std::string buf; buf.resize(256); std::string buf; buf.resize(256);
#ifdef _WIN32 #ifdef _WIN32
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
&buf[0], buf.length()); &buf[0], buf.length());
#else #else
const char *s = inet_ntop(AF_INET, &addr.sin_addr, const char *s = inet_ntop(AF_INET, &addr.sin_addr,
&buf[0], buf.length()); &buf[0], buf.length());
@ -126,11 +131,11 @@ class Socket {
#ifdef _WIN32 #ifdef _WIN32
WSADATA wsa_data; WSADATA wsa_data;
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) { if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
Socket::Error("Startup"); Socket::Error("Startup");
} }
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) { if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
WSACleanup(); WSACleanup();
utils::Error("Could not find a usable version of Winsock.dll\n"); utils::Error("Could not find a usable version of Winsock.dll\n");
} }
#endif #endif
} }
@ -209,7 +214,8 @@ class Socket {
inline int GetSockError(void) const { inline int GetSockError(void) const {
int error = 0; int error = 0;
socklen_t len = sizeof(error); socklen_t len = sizeof(error);
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0) { if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
reinterpret_cast<char*>(&error), &len) != 0) {
Error("GetSockError"); Error("GetSockError");
} }
return error; return error;
@ -419,109 +425,100 @@ class TCPSocket : public Socket{
} }
}; };
/*! \brief helper data structure to perform select */ /*! \brief helper data structure to perform poll */
struct SelectHelper { struct PollHelper {
public: public:
SelectHelper(void) {
FD_ZERO(&read_set);
FD_ZERO(&write_set);
FD_ZERO(&except_set);
maxfd = 0;
}
/*! /*!
* \brief add file descriptor to watch for read * \brief add file descriptor to watch for read
* \param fd file descriptor to be watched * \param fd file descriptor to be watched
*/ */
inline void WatchRead(SOCKET fd) { inline void WatchRead(SOCKET fd) {
FD_SET(fd, &read_set); auto& pfd = fds[fd];
if (fd > maxfd) maxfd = fd; pfd.fd = fd;
pfd.events |= POLLIN;
} }
/*! /*!
* \brief add file descriptor to watch for write * \brief add file descriptor to watch for write
* \param fd file descriptor to be watched * \param fd file descriptor to be watched
*/ */
inline void WatchWrite(SOCKET fd) { inline void WatchWrite(SOCKET fd) {
FD_SET(fd, &write_set); auto& pfd = fds[fd];
if (fd > maxfd) maxfd = fd; pfd.fd = fd;
pfd.events |= POLLOUT;
} }
/*! /*!
* \brief add file descriptor to watch for exception * \brief add file descriptor to watch for exception
* \param fd file descriptor to be watched * \param fd file descriptor to be watched
*/ */
inline void WatchException(SOCKET fd) { inline void WatchException(SOCKET fd) {
FD_SET(fd, &except_set); auto& pfd = fds[fd];
if (fd > maxfd) maxfd = fd; pfd.fd = fd;
pfd.events |= POLLPRI;
} }
/*! /*!
* \brief Check if the descriptor is ready for read * \brief Check if the descriptor is ready for read
* \param fd file descriptor to check status * \param fd file descriptor to check status
*/ */
inline bool CheckRead(SOCKET fd) const { inline bool CheckRead(SOCKET fd) const {
return FD_ISSET(fd, &read_set) != 0; const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
} }
/*! /*!
* \brief Check if the descriptor is ready for write * \brief Check if the descriptor is ready for write
* \param fd file descriptor to check status * \param fd file descriptor to check status
*/ */
inline bool CheckWrite(SOCKET fd) const { inline bool CheckWrite(SOCKET fd) const {
return FD_ISSET(fd, &write_set) != 0; const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
} }
/*! /*!
* \brief Check if the descriptor has any exception * \brief Check if the descriptor has any exception
* \param fd file descriptor to check status * \param fd file descriptor to check status
*/ */
inline bool CheckExcept(SOCKET fd) const { inline bool CheckExcept(SOCKET fd) const {
return FD_ISSET(fd, &except_set) != 0; const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0);
} }
/*! /*!
* \brief wait for exception event on a single descriptor * \brief wait for exception event on a single descriptor
* \param fd the file descriptor to wait the event for * \param fd the file descriptor to wait the event for
* \param timeout the timeout counter, can be 0, which means wait until the event happens * \param timeout the timeout counter, can be negative, which means wait until the event happens
* \return 1 if success, 0 if timeout, and -1 if error occurs * \return 1 if success, 0 if timeout, and -1 if error occurs
*/ */
inline static int WaitExcept(SOCKET fd, long timeout = 0) { // NOLINT(*) inline static int WaitExcept(SOCKET fd, long timeout = -1) { // NOLINT(*)
fd_set wait_set; pollfd pfd;
FD_ZERO(&wait_set); pfd.fd = fd;
FD_SET(fd, &wait_set); pfd.events = POLLPRI;
return Select_(static_cast<int>(fd + 1), return poll(&pfd, 1, timeout);
NULL, NULL, &wait_set, timeout);
} }
/*! /*!
* \brief perform select on the set defined * \brief perform poll on the set defined, read, write, exception
* \param select_read whether to watch for read event * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
* \param select_write whether to watch for write event * \return
* \param select_except whether to watch for exception event
* \param timeout specify timeout in micro-seconds(ms) if equals 0, means select will always block
* \return number of active descriptors selected,
* return -1 if error occurs
*/ */
inline int Select(long timeout = 0) { // NOLINT(*) inline void Poll(long timeout = -1) { // NOLINT(*)
int ret = Select_(static_cast<int>(maxfd + 1), std::vector<pollfd> fdset;
&read_set, &write_set, &except_set, timeout); fdset.reserve(fds.size());
for (auto kv : fds) {
fdset.push_back(kv.second);
}
int ret = poll(fdset.data(), fdset.size(), timeout);
if (ret == -1) { if (ret == -1) {
Socket::Error("Select"); Socket::Error("Poll");
}
return ret;
}
private:
inline static int Select_(int maxfd, fd_set *rfds,
fd_set *wfds, fd_set *efds, long timeout) { // NOLINT(*)
#if !defined(_WIN32)
utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE");
#endif
if (timeout == 0) {
return select(maxfd, rfds, wfds, efds, NULL);
} else { } else {
timeval tm; for (auto& pfd : fdset) {
tm.tv_usec = (timeout % 1000) * 1000; auto revents = pfd.revents & pfd.events;
tm.tv_sec = timeout / 1000; if (!revents) {
return select(maxfd, rfds, wfds, efds, &tm); fds.erase(pfd.fd);
} else {
fds[pfd.fd].events = revents;
}
}
} }
} }
SOCKET maxfd; std::unordered_map<SOCKET, pollfd> fds;
fd_set read_set, write_set, except_set;
}; };
} // namespace utils } // namespace utils
} // namespace rabit } // namespace rabit

View File

@ -4,6 +4,7 @@
all: model_recover_10_10k model_recover_10_10k_die_same model_recover_10_10k_die_hard local_recover_10_10k all: model_recover_10_10k model_recover_10_10k_die_same model_recover_10_10k_die_hard local_recover_10_10k
# this experiment test recovery with actually process exit, use keepalive to keep program alive # this experiment test recovery with actually process exit, use keepalive to keep program alive
# TODO: enable these tests once the issue in rabit is fixed
model_recover_10_10k: model_recover_10_10k:
../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0