support larger cluster (#73)

* fix error in dmlc#57, clean up comments and naming

* include missing packages, disable recovery tests for now

* disable local_recover tests until we have a bug fix

* support larger cluster

* fix lint, merge with master
This commit is contained in:
Chen Qin 2018-10-22 10:13:45 -07:00 committed by Nan Zhu
parent 69cdfae22f
commit 3a35dabfae
5 changed files with 121 additions and 123 deletions

View File

@ -208,9 +208,9 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
} else {
fprintf(stderr, "retry connect to ip(retry time %d): [%s]\n", retry, tracker_uri.c_str());
#ifdef _MSC_VER
Sleep(1);
Sleep(retry << 1);
#else
sleep(1);
sleep(retry << 1);
#endif
continue;
}
@ -454,29 +454,29 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (i == parent_index) {
if (size_down_in != total_size) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
// only watch for exception in live channels
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
finished = false;
}
if (size_up_out != total_size && size_up_out < size_up_reduce) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
} else {
if (links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
}
// size_write <= size_read
if (links[i].size_write != total_size) {
if (links[i].size_write < size_down_in) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
// only watch for exception in live channels
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
finished = false;
}
}
@ -484,17 +484,17 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
// finish running allreduce
if (finished) break;
// select must return
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
// read data from children
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && selecter.CheckRead(links[i].sock)) {
if (i != parent_index && watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
@ -551,7 +551,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
}
}
// read data from parent
if (selecter.CheckRead(links[parent_index].sock) &&
if (watcher.CheckRead(links[parent_index].sock) &&
total_size > size_down_in) {
ssize_t len = links[parent_index].sock.
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
@ -620,37 +620,37 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
while (true) {
bool finished = true;
// select helper
utils::SelectHelper selecter;
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (in_link == -2) {
selecter.WatchRead(links[i].sock); finished = false;
watcher.WatchRead(links[i].sock); finished = false;
}
if (i == in_link && links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock); finished = false;
watcher.WatchRead(links[i].sock); finished = false;
}
if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
if (links[i].size_write < size_in) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
finished = false;
}
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
}
// finish running
if (finished) break;
// select
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
if (in_link == -2) {
// probe in-link
for (int i = 0; i < nlink; ++i) {
if (selecter.CheckRead(links[i].sock)) {
if (watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
@ -663,7 +663,7 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
}
} else {
// read from in link
if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) {
if (in_link >= 0 && watcher.CheckRead(links[in_link].sock)) {
ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) {
return ReportError(&links[in_link], ret);
@ -717,20 +717,20 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
if (read_ptr != stop_read) {
selecter.WatchRead(next.sock);
watcher.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < read_ptr) {
selecter.WatchWrite(prev.sock);
watcher.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) break;
selecter.Select();
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
watcher.Poll();
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
size_t size = stop_read - read_ptr;
size_t start = read_ptr % total_size;
if (start + size > total_size) {
@ -811,20 +811,20 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
if (read_ptr != stop_read) {
selecter.WatchRead(next.sock);
watcher.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < reduce_ptr) {
selecter.WatchWrite(prev.sock);
watcher.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) break;
selecter.Select();
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
watcher.Poll();
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
if (ret != kSuccess) {
return ReportError(&next, ret);

View File

@ -69,30 +69,30 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
if (parent_index == -1) {
utils::Assert(stage != 2 && stage != 1, "invalie stage id");
}
// select helper
utils::SelectHelper selecter;
// poll helper
utils::PollHelper watcher;
bool done = (stage == 3);
for (int i = 0; i < nlink; ++i) {
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
switch (stage) {
case 0:
if (i != parent_index && links[i].size_read != sizeof(EdgeType)) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
}
break;
case 1:
if (i == parent_index) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
break;
case 2:
if (i == parent_index) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
}
break;
case 3:
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
done = false;
}
break;
@ -101,11 +101,11 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
}
// finish all the stages, and write out message
if (done) break;
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
@ -114,7 +114,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
// read data from children
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
if (selecter.CheckRead(links[i].sock)) {
if (watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToArray(&edge_in[i], sizeof(EdgeType));
if (ret != kSuccess) return ReportError(&links[i], ret);
}

View File

@ -334,7 +334,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
if (len == sizeof(sig)) all_links[i].size_write = 2;
}
}
utils::SelectHelper rsel;
utils::PollHelper rsel;
bool finished = true;
for (int i = 0; i < nlink; ++i) {
if (all_links[i].size_write != 2 && !all_links[i].sock.BadSocket()) {
@ -343,15 +343,15 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
}
if (finished) break;
// wait to read from the channels to discard data
rsel.Select();
rsel.Poll();
}
for (int i = 0; i < nlink; ++i) {
if (!all_links[i].sock.BadSocket()) {
utils::SelectHelper::WaitExcept(all_links[i].sock);
utils::PollHelper::WaitExcept(all_links[i].sock);
}
}
while (true) {
utils::SelectHelper rsel;
utils::PollHelper rsel;
bool finished = true;
for (int i = 0; i < nlink; ++i) {
if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) {
@ -359,7 +359,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
}
}
if (finished) break;
rsel.Select();
rsel.Poll();
for (int i = 0; i < nlink; ++i) {
if (all_links[i].sock.BadSocket()) continue;
if (all_links[i].size_read == 0) {
@ -624,32 +624,32 @@ AllreduceRobust::TryRecoverData(RecoverType role,
}
while (true) {
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (i == recv_link && links[i].size_read != size) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
finished = false;
}
if (req_in[i] && links[i].size_write != size) {
if (role == kHaveData ||
(links[recv_link].size_read != links[i].size_write)) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
finished = false;
}
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
}
if (finished) break;
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
if (role == kRequestData) {
const int pid = recv_link;
if (selecter.CheckRead(links[pid].sock)) {
if (watcher.CheckRead(links[pid].sock)) {
ReturnType ret = links[pid].ReadToArray(sendrecvbuf_, size);
if (ret != kSuccess) {
return ReportError(&links[pid], ret);
@ -677,7 +677,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
if (role == kPassData) {
const int pid = recv_link;
const size_t buffer_size = links[pid].buffer_size;
if (selecter.CheckRead(links[pid].sock)) {
if (watcher.CheckRead(links[pid].sock)) {
size_t min_write = size;
for (int i = 0; i < nlink; ++i) {
if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
@ -1144,22 +1144,22 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
char *buf = reinterpret_cast<char*>(sendrecvbuf_);
while (true) {
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
if (read_ptr != read_end) {
selecter.WatchRead(prev.sock);
watcher.WatchRead(prev.sock);
finished = false;
}
if (write_ptr < read_ptr && write_ptr != write_end) {
selecter.WatchWrite(next.sock);
watcher.WatchWrite(next.sock);
finished = false;
}
selecter.WatchException(prev.sock);
selecter.WatchException(next.sock);
watcher.WatchException(prev.sock);
watcher.WatchException(next.sock);
if (finished) break;
selecter.Select();
if (selecter.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept);
if (selecter.CheckExcept(next.sock)) return ReportError(&next, kGetExcept);
if (read_ptr != read_end && selecter.CheckRead(prev.sock)) {
watcher.Poll();
if (watcher.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept);
if (watcher.CheckExcept(next.sock)) return ReportError(&next, kGetExcept);
if (read_ptr != read_end && watcher.CheckRead(prev.sock)) {
ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr);
if (len == 0) {
prev.sock.Close(); return ReportError(&prev, kRecvZeroLen);

View File

@ -20,17 +20,22 @@
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/select.h>
#include <sys/ioctl.h>
#endif
#include <string>
#include <cstring>
#include <vector>
#include <unordered_map>
#include "../include/rabit/internal/utils.h"
#if defined(_WIN32)
typedef int ssize_t;
typedef int sock_size_t;
static inline int poll(struct pollfd *pfd, int nfds,
int timeout) { return WSAPoll ( pfd, nfds, timeout ); }
#else
#include <sys/poll.h>
typedef int SOCKET;
typedef size_t sock_size_t;
const int INVALID_SOCKET = -1;
@ -78,7 +83,7 @@ struct SockAddr {
std::string buf; buf.resize(256);
#ifdef _WIN32
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
&buf[0], buf.length());
&buf[0], buf.length());
#else
const char *s = inet_ntop(AF_INET, &addr.sin_addr,
&buf[0], buf.length());
@ -126,11 +131,11 @@ class Socket {
#ifdef _WIN32
WSADATA wsa_data;
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
Socket::Error("Startup");
Socket::Error("Startup");
}
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
WSACleanup();
utils::Error("Could not find a usable version of Winsock.dll\n");
WSACleanup();
utils::Error("Could not find a usable version of Winsock.dll\n");
}
#endif
}
@ -209,7 +214,8 @@ class Socket {
inline int GetSockError(void) const {
int error = 0;
socklen_t len = sizeof(error);
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0) {
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
reinterpret_cast<char*>(&error), &len) != 0) {
Error("GetSockError");
}
return error;
@ -419,109 +425,100 @@ class TCPSocket : public Socket{
}
};
/*! \brief helper data structure to perform select */
struct SelectHelper {
/*! \brief helper data structure to perform poll */
struct PollHelper {
public:
SelectHelper(void) {
FD_ZERO(&read_set);
FD_ZERO(&write_set);
FD_ZERO(&except_set);
maxfd = 0;
}
/*!
* \brief add file descriptor to watch for read
* \param fd file descriptor to be watched
*/
inline void WatchRead(SOCKET fd) {
FD_SET(fd, &read_set);
if (fd > maxfd) maxfd = fd;
auto& pfd = fds[fd];
pfd.fd = fd;
pfd.events |= POLLIN;
}
/*!
* \brief add file descriptor to watch for write
* \param fd file descriptor to be watched
*/
inline void WatchWrite(SOCKET fd) {
FD_SET(fd, &write_set);
if (fd > maxfd) maxfd = fd;
auto& pfd = fds[fd];
pfd.fd = fd;
pfd.events |= POLLOUT;
}
/*!
* \brief add file descriptor to watch for exception
* \param fd file descriptor to be watched
*/
inline void WatchException(SOCKET fd) {
FD_SET(fd, &except_set);
if (fd > maxfd) maxfd = fd;
auto& pfd = fds[fd];
pfd.fd = fd;
pfd.events |= POLLPRI;
}
/*!
* \brief Check if the descriptor is ready for read
* \param fd file descriptor to check status
*/
inline bool CheckRead(SOCKET fd) const {
return FD_ISSET(fd, &read_set) != 0;
const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
}
/*!
* \brief Check if the descriptor is ready for write
* \param fd file descriptor to check status
*/
inline bool CheckWrite(SOCKET fd) const {
return FD_ISSET(fd, &write_set) != 0;
const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
}
/*!
* \brief Check if the descriptor has any exception
* \param fd file descriptor to check status
*/
inline bool CheckExcept(SOCKET fd) const {
return FD_ISSET(fd, &except_set) != 0;
const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0);
}
/*!
* \brief wait for exception event on a single descriptor
* \param fd the file descriptor to wait the event for
* \param timeout the timeout counter, can be 0, which means wait until the event happen
* \param timeout the timeout in milliseconds; a negative value means wait until the event happens
* \return 1 if success, 0 if timeout, and -1 if error occurs
*/
inline static int WaitExcept(SOCKET fd, long timeout = 0) { // NOLINT(*)
fd_set wait_set;
FD_ZERO(&wait_set);
FD_SET(fd, &wait_set);
return Select_(static_cast<int>(fd + 1),
NULL, NULL, &wait_set, timeout);
inline static int WaitExcept(SOCKET fd, long timeout = -1) { // NOLINT(*)
pollfd pfd;
pfd.fd = fd;
pfd.events = POLLPRI;
return poll(&pfd, 1, timeout);
}
/*!
* \brief peform select on the set defined
* \param select_read whether to watch for read event
* \param select_write whether to watch for write event
* \param select_except whether to watch for exception event
* \param timeout specify timeout in micro-seconds(ms) if equals 0, means select will always block
* \return number of active descriptors selected,
* return -1 if error occurs
* \brief perform poll on the watched set (read, write, exception)
* \param timeout timeout in milliseconds (ms); if negative, poll blocks until an event occurs
* \return
*/
inline int Select(long timeout = 0) { // NOLINT(*)
int ret = Select_(static_cast<int>(maxfd + 1),
&read_set, &write_set, &except_set, timeout);
inline void Poll(long timeout = -1) { // NOLINT(*)
std::vector<pollfd> fdset;
fdset.reserve(fds.size());
for (auto kv : fds) {
fdset.push_back(kv.second);
}
int ret = poll(fdset.data(), fdset.size(), timeout);
if (ret == -1) {
Socket::Error("Select");
}
return ret;
}
private:
inline static int Select_(int maxfd, fd_set *rfds,
fd_set *wfds, fd_set *efds, long timeout) { // NOLINT(*)
#if !defined(_WIN32)
utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE");
#endif
if (timeout == 0) {
return select(maxfd, rfds, wfds, efds, NULL);
Socket::Error("Poll");
} else {
timeval tm;
tm.tv_usec = (timeout % 1000) * 1000;
tm.tv_sec = timeout / 1000;
return select(maxfd, rfds, wfds, efds, &tm);
for (auto& pfd : fdset) {
auto revents = pfd.revents & pfd.events;
if (!revents) {
fds.erase(pfd.fd);
} else {
fds[pfd.fd].events = revents;
}
}
}
}
SOCKET maxfd;
fd_set read_set, write_set, except_set;
std::unordered_map<SOCKET, pollfd> fds;
};
} // namespace utils
} // namespace rabit

View File

@ -4,6 +4,7 @@
all: model_recover_10_10k model_recover_10_10k_die_same model_recover_10_10k_die_hard local_recover_10_10k
# this experiment test recovery with actually process exit, use keepalive to keep program alive
# TODO: enable these tests once the issue in rabit is fixed
model_recover_10_10k:
../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0