livelock in oob send recv

This commit is contained in:
tqchen 2014-11-28 21:56:58 -08:00
parent a30075794b
commit aa54a038f2
2 changed files with 173 additions and 89 deletions

View File

@ -35,6 +35,8 @@ class AllReduceManager : public IEngine {
// constant one byte out of band message to indicate error happening
// and mark for channel cleanup
const static char kOOBReset = 95;
// and mark for channel cleanup
const static char kOOBResetAck = 97;
AllReduceManager(void) {
master_uri = "NULL";
@ -148,13 +150,9 @@ class AllReduceManager : public IEngine {
// close listening sockets
sock_listen.Close();
// setup selecter
selecter.Clear();
for (size_t i = 0; i < links.size(); ++i) {
// set the socket to non-blocking mode
links[i].sock.SetNonBlock(true);
selecter.WatchRead(links[i].sock);
selecter.WatchWrite(links[i].sock);
selecter.WatchException(links[i].sock);
}
// done
}
@ -211,7 +209,13 @@ class AllReduceManager : public IEngine {
// while we have not passed the messages out
while(true) {
selecter.Select();
// select helper
utils::SelectHelper selecter;
for (size_t i = 0; i < links.size(); ++i) {
selecter.WatchRead(links[i].sock);
selecter.WatchWrite(links[i].sock);
selecter.WatchException(links[i].sock);
}
if (in_link == -2) {
// probe in-link
for (int i = 0; i < nlink; ++i) {
@ -277,71 +281,118 @@ class AllReduceManager : public IEngine {
links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size);
links[i].ResetSize();
}
printf("[%d] start to reset link\n", rank);
printf("[%d] start to reset link\n", rank);
while (true) {
if (selecter.Select() == -1) {
if (errno == EBADF || errno == EINTR) return kSockError;
utils::Socket::Error("select");
}
printf("[%d] loop\n", rank);
bool finished = true;
for (int i = 0; i < nlink; ++i) {
if (selecter.CheckWrite(links[i].sock)) {
if (links[i].size_write == 0) {
char sig = kOOBReset;
ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB);
if (len != -1) {
links[i].size_write += len;
} else {
if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
}
if (links[i].sock.BadSocket()) continue;
if (links[i].size_write == 0) {
char sig = kOOBReset;
ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB);
// error will be filtered in next loop
if (len != -1) {
links[i].size_write += len;
printf("[%d] send OOB success\n", rank);
}
}
// need to send OOB to every other link
if (links[i].size_write == 0) finished = false;
// need to receive OOB from every link, or already cleanup some link
if (!links[i].oob_clear && !selecter.CheckExcept(links[i].sock)) finished = false;
}
if (finished) break;
}
printf("[%d] finish send all OOB\n", rank);
// wait for incoming except from all links
for (int i = 0; i < nlink; ++ i) {
if (links[i].sock.BadSocket()) continue;
printf("[%d] wait except\n", rank);
if (utils::SelectHelper::WaitExcept(links[i].sock) == -1) {
utils::Socket::Error("select");
}
printf("[%d] finish wait except\n", rank);
}
printf("[%d] start to discard link\n", rank);
// read and discard data from all channels until pass mark
while (true) {
if (selecter.Select() == -1) {
if (errno == EBADF || errno == EINTR) return kSockError;
utils::Socket::Error("select");
}
utils::SelectHelper rsel;
bool finished = true;
for (int i = 0; i < nlink; ++i) {
if (selecter.CheckExcept(links[i].sock)) {
if (links[i].sock.BadSocket()) continue;
if (links[i].size_read == 0) {
int atmark = links[i].sock.AtMark();
if (atmark < 0) return kSockError;
if (atmark == 1) {
char oob_msg;
ssize_t len = links[i].sock.Recv(&oob_msg, sizeof(oob_msg), MSG_OOB);
if (len == -1 && errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
utils::Assert(oob_msg == kOOBReset, "wrong oob msg");
} else {
ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size);
if (len == -1) {
// when error happens here, oob_clear will remember
if (errno == EAGAIN && errno == EWOULDBLOCK) printf("would block\n");
} else {
printf("[%d] discard %ld bytes\n", rank, len);
if (len == -1 && errno != EAGAIN && errno != EWOULDBLOCK) {
finished = false; continue;
}
// the existing exception already cleared by this loop
if (len == -1 && errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
utils::Assert(oob_msg == kOOBReset, "wrong oob msg");
links[i].size_read = 1;
} else {
finished = false;
rsel.WatchRead(links[i].sock);
}
finished = false;
} else {
links[i].oob_clear = true;
}
}
if (finished) break;
// wait to read from the channels to discard data
rsel.Select();
printf("[%d] select finish read from\n", rank);
for (int i = 0; i < nlink; ++i) {
if (links[i].sock.BadSocket()) continue;
if (rsel.CheckRead(links[i].sock)) {
ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size);
// zero length, remote closed the connection, close socket
if (len == 0) {
links[i].sock.Close();
} else if (len == -1) {
// when error happens here, oob_clear will remember
if (errno == EAGAIN && errno == EWOULDBLOCK) printf("would block\n");
} else {
printf("[%d] discard %ld bytes\n", rank, len);
}
}
}
}
// mark oob_clear mark as false
printf("[%d] discard all success\n", rank);
// start synchronization step
for (int i = 0; i < nlink; ++i) {
links[i].oob_clear = false;
links[i].ResetSize();
}
while (true) {
// selecter for TryResetLinks
utils::SelectHelper rsel;
for (int i = 0; i < nlink; ++i) {
if (links[i].sock.BadSocket()) continue;
if (links[i].size_read == 0) rsel.WatchRead(links[i].sock);
if (links[i].size_write == 0) rsel.WatchWrite(links[i].sock);
}
printf("[%d] before select\n", rank);
rsel.Select();
printf("[%d] after select\n", rank);
bool finished = true;
for (int i = 0; i < nlink; ++i) {
if (links[i].sock.BadSocket()) continue;
if (links[i].size_read == 0 && rsel.CheckRead(links[i].sock)) {
char ack;
links[i].ReadToArray(&ack, sizeof(ack));
if (links[i].size_read != 0) {
utils::Assert(ack == kOOBResetAck, "expect ack message");
}
}
if (links[i].size_write == 0 && rsel.CheckWrite(links[i].sock)) {
char ack = kOOBResetAck;
links[i].WriteFromArray(&ack, sizeof(ack));
}
if (links[i].size_read == 0 || links[i].size_write == 0) finished = false;
}
if (finished) break;
}
printf("[%d] after the read write data success\n", rank);
for (int i = 0; i < nlink; ++i) {
if (links[i].sock.BadSocket()) return kSockError;
}
return kSuccess;
}
// Run AllReduce, return if success
@ -376,10 +427,15 @@ class AllReduceManager : public IEngine {
// while we have not passed the messages out
while (true) {
if (selecter.Select() == -1) {
if (errno == EBADF || errno == EINTR) return kSockError;
utils::Socket::Error("select");
// select helper
utils::SelectHelper selecter;
for (size_t i = 0; i < links.size(); ++i) {
selecter.WatchRead(links[i].sock);
selecter.WatchWrite(links[i].sock);
selecter.WatchException(links[i].sock);
}
// select must return
selecter.Select();
// exception handling
for (int i = 0; i < nlink; ++i) {
// recive OOB message from some link
@ -437,9 +493,12 @@ class AllReduceManager : public IEngine {
}
}
// read data from parent
if (selecter.CheckRead(links[parent_index].sock)) {
if (selecter.CheckRead(links[parent_index].sock) && total_size > size_down_in) {
ssize_t len = links[parent_index].sock.
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
if (len == 0) {
links[parent_index].sock.Close(); return kSockError;
}
if (len != -1) {
size_down_in += static_cast<size_t>(len);
utils::Assert(size_down_in <= size_up_out, "AllReduce: boundary error");
@ -482,10 +541,8 @@ class AllReduceManager : public IEngine {
char *buffer_head;
// buffer size, in bytes
size_t buffer_size;
// state used by TryResetLinks, whether a link is already cleaned from OOB mark
bool oob_clear;
// constructor
LinkRecord(void) : oob_clear(false) {}
LinkRecord(void) {}
// initialize buffer
inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) {
size_t n = (type_nbytes * count + 7)/ 8;
@ -511,8 +568,13 @@ class AllReduceManager : public IEngine {
size_t ngap = size_read - protect_start;
utils::Assert(ngap <= buffer_size, "AllReduce: boundary check");
size_t offset = size_read % buffer_size;
size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
if (nmax == 0) return true;
ssize_t len = sock.Recv(buffer_head + offset, nmax);
// length equals 0, remote disconnected
if (len == 0) {
sock.Close(); return false;
}
if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK;
size_read += static_cast<size_t>(len);
return true;
@ -525,8 +587,13 @@ class AllReduceManager : public IEngine {
* \return true if it is an successful read, false if there is some error happens, check errno
*/
inline bool ReadToArray(void *recvbuf_, size_t max_size) {
if (max_size == size_read ) return true;
char *p = static_cast<char*>(recvbuf_);
ssize_t len = sock.Recv(p + size_read, max_size - size_read);
// length equals 0, remote disconnected
if (len == 0) {
sock.Close(); return false;
}
if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK;
size_read += static_cast<size_t>(len);
return true;
@ -613,8 +680,6 @@ class AllReduceManager : public IEngine {
int parent_index;
// sockets of all links
std::vector<LinkRecord> links;
// select helper
utils::SelectHelper selecter;
//----- meta information-----
// uri of current host, to be set by Init
std::string host_uri;

View File

@ -164,6 +164,26 @@ class Socket {
}
return -1;
}
/*! \brief get last error code if any */
inline int GetSockError(void) const {
int error = 0;
socklen_t len = sizeof(error);
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &error, &len) != 0) {
Error("GetSockError");
}
return error;
}
/*! \brief check if anything bad happens */
inline bool BadSocket(void) const {
if (IsClosed()) return true;
int err = GetSockError();
if (err == EBADF || err == EINTR) return true;
return false;
}
/*! \brief check if socket is already closed */
inline bool IsClosed(void) const {
return sockfd == INVALID_SOCKET;
}
/*! \brief close the socket */
inline void Close(void) {
if (sockfd != INVALID_SOCKET) {
@ -177,7 +197,6 @@ class Socket {
Error("Socket::Close double close the socket or close without create");
}
}
// report an socket error
inline static void Error(const char *msg) {
int errsv = errno;
@ -267,9 +286,8 @@ class TCPSocket : public Socket{
*/
inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
char *buf = reinterpret_cast<char*>(buf_);
if (len == 0) return 0;
return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
}
}
/*!
* \brief peform block write that will attempt to send all data out
* can still return smaller than request when error occurs
@ -319,14 +337,17 @@ class TCPSocket : public Socket{
struct SelectHelper {
public:
SelectHelper(void) {
this->Clear();
FD_ZERO(&read_set);
FD_ZERO(&write_set);
FD_ZERO(&except_set);
maxfd = 0;
}
/*!
* \brief add file descriptor to watch for read
* \param fd file descriptor to be watched
*/
inline void WatchRead(SOCKET fd) {
read_fds.push_back(fd);
FD_SET(fd, &read_set);
if (fd > maxfd) maxfd = fd;
}
/*!
@ -334,7 +355,7 @@ struct SelectHelper {
* \param fd file descriptor to be watched
*/
inline void WatchWrite(SOCKET fd) {
write_fds.push_back(fd);
FD_SET(fd, &write_set);
if (fd > maxfd) maxfd = fd;
}
/*!
@ -342,7 +363,7 @@ struct SelectHelper {
* \param fd file descriptor to be watched
*/
inline void WatchException(SOCKET fd) {
except_fds.push_back(fd);
FD_SET(fd, &except_set);
if (fd > maxfd) maxfd = fd;
}
/*!
@ -367,51 +388,49 @@ struct SelectHelper {
return FD_ISSET(fd, &except_set) != 0;
}
/*!
* \brief clear all the monitored descriptors
* \brief wait for exception event on a single descriptor
* \param fd the file descriptor to wait the event for
* \param timeout the timeout counter, can be 0, which means wait until the event happen
* \return 1 if success, 0 if timeout, and -1 if error occurs
*/
inline void Clear(void) {
read_fds.clear();
write_fds.clear();
except_fds.clear();
maxfd = 0;
}
inline static int WaitExcept(SOCKET fd, long timeout = 0) {
fd_set wait_set;
FD_ZERO(&wait_set);
FD_SET(fd, &wait_set);
return Select_(static_cast<int>(fd + 1), NULL, NULL, &wait_set, timeout);
}
/*!
* \brief peform select on the set defined
* \param select_read whether to watch for read event
* \param select_write whether to watch for write event
* \param select_except whether to watch for exception event
* \param timeout specify timeout in micro-seconds(ms) if equals 0, means select will always block
* \return number of active descriptors selected,
* return -1 if error occurs
*/
inline int Select(long timeout = 0) {
FD_ZERO(&read_set);
FD_ZERO(&write_set);
FD_ZERO(&except_set);
for (size_t i = 0; i < read_fds.size(); ++i) {
FD_SET(read_fds[i], &read_set);
}
for (size_t i = 0; i < write_fds.size(); ++i) {
FD_SET(write_fds[i], &write_set);
}
for (size_t i = 0; i < except_fds.size(); ++i) {
FD_SET(except_fds[i], &except_set);
}
int ret;
if (timeout == 0) {
ret = select(static_cast<int>(maxfd + 1), &read_set,
&write_set, &except_set, NULL);
} else {
timeval tm;
tm.tv_usec = (timeout % 1000) * 1000;
tm.tv_sec = timeout / 1000;
ret = select(static_cast<int>(maxfd + 1), &read_set,
&write_set, &except_set, &tm);
int ret = Select_(static_cast<int>(maxfd + 1),
&read_set, &write_set, &except_set, timeout);
if (ret == -1) {
Socket::Error("Select");
}
return ret;
}
private:
inline static int Select_(int maxfd, fd_set *rfds, fd_set *wfds, fd_set *efds, long timeout) {
if (timeout == 0) {
return select(maxfd, rfds, wfds, efds, NULL);
} else {
timeval tm;
tm.tv_usec = (timeout % 1000) * 1000;
tm.tv_sec = timeout / 1000;
return select(maxfd, rfds, wfds, efds, &tm);
}
}
SOCKET maxfd;
fd_set read_set, write_set, except_set;
std::vector<SOCKET> read_fds, write_fds, except_fds;
};
}
#endif