livelock in oob send recv
This commit is contained in:
parent
a30075794b
commit
aa54a038f2
@ -35,6 +35,8 @@ class AllReduceManager : public IEngine {
|
||||
// constant one byte out of band message to indicate error happening
|
||||
// and mark for channel cleanup
|
||||
const static char kOOBReset = 95;
|
||||
// constant one byte message used to acknowledge a link reset
|
||||
const static char kOOBResetAck = 97;
|
||||
|
||||
AllReduceManager(void) {
|
||||
master_uri = "NULL";
|
||||
@ -148,13 +150,9 @@ class AllReduceManager : public IEngine {
|
||||
// close listening sockets
|
||||
sock_listen.Close();
|
||||
// setup selecter
|
||||
selecter.Clear();
|
||||
for (size_t i = 0; i < links.size(); ++i) {
|
||||
// set the socket to non-blocking mode
|
||||
links[i].sock.SetNonBlock(true);
|
||||
selecter.WatchRead(links[i].sock);
|
||||
selecter.WatchWrite(links[i].sock);
|
||||
selecter.WatchException(links[i].sock);
|
||||
}
|
||||
// done
|
||||
}
|
||||
@ -211,7 +209,13 @@ class AllReduceManager : public IEngine {
|
||||
|
||||
// while we have not passed the messages out
|
||||
while(true) {
|
||||
selecter.Select();
|
||||
// select helper
|
||||
utils::SelectHelper selecter;
|
||||
for (size_t i = 0; i < links.size(); ++i) {
|
||||
selecter.WatchRead(links[i].sock);
|
||||
selecter.WatchWrite(links[i].sock);
|
||||
selecter.WatchException(links[i].sock);
|
||||
}
|
||||
if (in_link == -2) {
|
||||
// probe in-link
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
@ -277,71 +281,118 @@ class AllReduceManager : public IEngine {
|
||||
links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size);
|
||||
links[i].ResetSize();
|
||||
}
|
||||
printf("[%d] start to reset link\n", rank);
|
||||
printf("[%d] start to reset link\n", rank);
|
||||
while (true) {
|
||||
if (selecter.Select() == -1) {
|
||||
if (errno == EBADF || errno == EINTR) return kSockError;
|
||||
utils::Socket::Error("select");
|
||||
}
|
||||
printf("[%d] loop\n", rank);
|
||||
bool finished = true;
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (selecter.CheckWrite(links[i].sock)) {
|
||||
if (links[i].size_write == 0) {
|
||||
char sig = kOOBReset;
|
||||
ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB);
|
||||
if (len != -1) {
|
||||
links[i].size_write += len;
|
||||
} else {
|
||||
if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
|
||||
}
|
||||
if (links[i].sock.BadSocket()) continue;
|
||||
if (links[i].size_write == 0) {
|
||||
char sig = kOOBReset;
|
||||
ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB);
|
||||
// error will be filtered in next loop
|
||||
if (len != -1) {
|
||||
links[i].size_write += len;
|
||||
printf("[%d] send OOB success\n", rank);
|
||||
}
|
||||
}
|
||||
// need to send OOB to every other link
|
||||
if (links[i].size_write == 0) finished = false;
|
||||
// need to receive OOB from every link, or already cleanup some link
|
||||
if (!links[i].oob_clear && !selecter.CheckExcept(links[i].sock)) finished = false;
|
||||
}
|
||||
if (finished) break;
|
||||
}
|
||||
printf("[%d] finish send all OOB\n", rank);
|
||||
// wait for incoming except from all links
|
||||
for (int i = 0; i < nlink; ++ i) {
|
||||
if (links[i].sock.BadSocket()) continue;
|
||||
printf("[%d] wait except\n", rank);
|
||||
if (utils::SelectHelper::WaitExcept(links[i].sock) == -1) {
|
||||
utils::Socket::Error("select");
|
||||
}
|
||||
printf("[%d] finish wait except\n", rank);
|
||||
}
|
||||
printf("[%d] start to discard link\n", rank);
|
||||
// read and discard data from all channels until pass mark
|
||||
while (true) {
|
||||
if (selecter.Select() == -1) {
|
||||
if (errno == EBADF || errno == EINTR) return kSockError;
|
||||
utils::Socket::Error("select");
|
||||
}
|
||||
utils::SelectHelper rsel;
|
||||
bool finished = true;
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (selecter.CheckExcept(links[i].sock)) {
|
||||
if (links[i].sock.BadSocket()) continue;
|
||||
if (links[i].size_read == 0) {
|
||||
int atmark = links[i].sock.AtMark();
|
||||
if (atmark < 0) return kSockError;
|
||||
if (atmark == 1) {
|
||||
char oob_msg;
|
||||
ssize_t len = links[i].sock.Recv(&oob_msg, sizeof(oob_msg), MSG_OOB);
|
||||
if (len == -1 && errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
|
||||
utils::Assert(oob_msg == kOOBReset, "wrong oob msg");
|
||||
} else {
|
||||
ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size);
|
||||
if (len == -1) {
|
||||
// when error happens here, oob_clear will remember
|
||||
if (errno == EAGAIN && errno == EWOULDBLOCK) printf("would block\n");
|
||||
} else {
|
||||
printf("[%d] discard %ld bytes\n", rank, len);
|
||||
if (len == -1 && errno != EAGAIN && errno != EWOULDBLOCK) {
|
||||
finished = false; continue;
|
||||
}
|
||||
// the existing exception already cleared by this loop
|
||||
if (len == -1 && errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
|
||||
utils::Assert(oob_msg == kOOBReset, "wrong oob msg");
|
||||
links[i].size_read = 1;
|
||||
} else {
|
||||
finished = false;
|
||||
rsel.WatchRead(links[i].sock);
|
||||
}
|
||||
finished = false;
|
||||
} else {
|
||||
links[i].oob_clear = true;
|
||||
}
|
||||
}
|
||||
if (finished) break;
|
||||
// wait to read from the channels to discard data
|
||||
rsel.Select();
|
||||
printf("[%d] select finish read from\n", rank);
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (links[i].sock.BadSocket()) continue;
|
||||
if (rsel.CheckRead(links[i].sock)) {
|
||||
ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size);
|
||||
// zero length, remote closed the connection, close socket
|
||||
if (len == 0) {
|
||||
links[i].sock.Close();
|
||||
} else if (len == -1) {
|
||||
// when error happens here, oob_clear will remember
|
||||
if (errno == EAGAIN && errno == EWOULDBLOCK) printf("would block\n");
|
||||
} else {
|
||||
printf("[%d] discard %ld bytes\n", rank, len);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// mark oob_clear mark as false
|
||||
printf("[%d] discard all success\n", rank);
|
||||
// start synchronization step
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
links[i].oob_clear = false;
|
||||
links[i].ResetSize();
|
||||
}
|
||||
while (true) {
|
||||
// selecter for TryResetLinks
|
||||
utils::SelectHelper rsel;
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (links[i].sock.BadSocket()) continue;
|
||||
if (links[i].size_read == 0) rsel.WatchRead(links[i].sock);
|
||||
if (links[i].size_write == 0) rsel.WatchWrite(links[i].sock);
|
||||
}
|
||||
printf("[%d] before select\n", rank);
|
||||
rsel.Select();
|
||||
printf("[%d] after select\n", rank);
|
||||
bool finished = true;
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (links[i].sock.BadSocket()) continue;
|
||||
if (links[i].size_read == 0 && rsel.CheckRead(links[i].sock)) {
|
||||
char ack;
|
||||
links[i].ReadToArray(&ack, sizeof(ack));
|
||||
if (links[i].size_read != 0) {
|
||||
utils::Assert(ack == kOOBResetAck, "expect ack message");
|
||||
}
|
||||
}
|
||||
if (links[i].size_write == 0 && rsel.CheckWrite(links[i].sock)) {
|
||||
char ack = kOOBResetAck;
|
||||
links[i].WriteFromArray(&ack, sizeof(ack));
|
||||
}
|
||||
if (links[i].size_read == 0 || links[i].size_write == 0) finished = false;
|
||||
}
|
||||
if (finished) break;
|
||||
}
|
||||
printf("[%d] after the read write data success\n", rank);
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (links[i].sock.BadSocket()) return kSockError;
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
// Run AllReduce, return if success
|
||||
@ -376,10 +427,15 @@ class AllReduceManager : public IEngine {
|
||||
|
||||
// while we have not passed the messages out
|
||||
while (true) {
|
||||
if (selecter.Select() == -1) {
|
||||
if (errno == EBADF || errno == EINTR) return kSockError;
|
||||
utils::Socket::Error("select");
|
||||
// select helper
|
||||
utils::SelectHelper selecter;
|
||||
for (size_t i = 0; i < links.size(); ++i) {
|
||||
selecter.WatchRead(links[i].sock);
|
||||
selecter.WatchWrite(links[i].sock);
|
||||
selecter.WatchException(links[i].sock);
|
||||
}
|
||||
// select must return
|
||||
selecter.Select();
|
||||
// exception handling
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
// receive OOB message from some link
|
||||
@ -437,9 +493,12 @@ class AllReduceManager : public IEngine {
|
||||
}
|
||||
}
|
||||
// read data from parent
|
||||
if (selecter.CheckRead(links[parent_index].sock)) {
|
||||
if (selecter.CheckRead(links[parent_index].sock) && total_size > size_down_in) {
|
||||
ssize_t len = links[parent_index].sock.
|
||||
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
|
||||
if (len == 0) {
|
||||
links[parent_index].sock.Close(); return kSockError;
|
||||
}
|
||||
if (len != -1) {
|
||||
size_down_in += static_cast<size_t>(len);
|
||||
utils::Assert(size_down_in <= size_up_out, "AllReduce: boundary error");
|
||||
@ -482,10 +541,8 @@ class AllReduceManager : public IEngine {
|
||||
char *buffer_head;
|
||||
// buffer size, in bytes
|
||||
size_t buffer_size;
|
||||
// state used by TryResetLinks, whether a link is already cleaned from OOB mark
|
||||
bool oob_clear;
|
||||
// constructor
|
||||
LinkRecord(void) : oob_clear(false) {}
|
||||
LinkRecord(void) {}
|
||||
// initialize buffer
|
||||
inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) {
|
||||
size_t n = (type_nbytes * count + 7)/ 8;
|
||||
@ -511,8 +568,13 @@ class AllReduceManager : public IEngine {
|
||||
size_t ngap = size_read - protect_start;
|
||||
utils::Assert(ngap <= buffer_size, "AllReduce: boundary check");
|
||||
size_t offset = size_read % buffer_size;
|
||||
size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
|
||||
size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
|
||||
if (nmax == 0) return true;
|
||||
ssize_t len = sock.Recv(buffer_head + offset, nmax);
|
||||
// length equals 0, remote disconnected
|
||||
if (len == 0) {
|
||||
sock.Close(); return false;
|
||||
}
|
||||
if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK;
|
||||
size_read += static_cast<size_t>(len);
|
||||
return true;
|
||||
@ -525,8 +587,13 @@ class AllReduceManager : public IEngine {
|
||||
* \return true if it is a successful read, false if some error happens, check errno
|
||||
*/
|
||||
inline bool ReadToArray(void *recvbuf_, size_t max_size) {
|
||||
if (max_size == size_read ) return true;
|
||||
char *p = static_cast<char*>(recvbuf_);
|
||||
ssize_t len = sock.Recv(p + size_read, max_size - size_read);
|
||||
// length equals 0, remote disconnected
|
||||
if (len == 0) {
|
||||
sock.Close(); return false;
|
||||
}
|
||||
if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK;
|
||||
size_read += static_cast<size_t>(len);
|
||||
return true;
|
||||
@ -613,8 +680,6 @@ class AllReduceManager : public IEngine {
|
||||
int parent_index;
|
||||
// sockets of all links
|
||||
std::vector<LinkRecord> links;
|
||||
// select helper
|
||||
utils::SelectHelper selecter;
|
||||
//----- meta information-----
|
||||
// uri of current host, to be set by Init
|
||||
std::string host_uri;
|
||||
|
||||
93
src/socket.h
93
src/socket.h
@ -164,6 +164,26 @@ class Socket {
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
/*! \brief get last error code if any */
|
||||
inline int GetSockError(void) const {
|
||||
int error = 0;
|
||||
socklen_t len = sizeof(error);
|
||||
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &error, &len) != 0) {
|
||||
Error("GetSockError");
|
||||
}
|
||||
return error;
|
||||
}
|
||||
/*! \brief check if anything bad happens */
|
||||
inline bool BadSocket(void) const {
|
||||
if (IsClosed()) return true;
|
||||
int err = GetSockError();
|
||||
if (err == EBADF || err == EINTR) return true;
|
||||
return false;
|
||||
}
|
||||
/*! \brief check if socket is already closed */
|
||||
inline bool IsClosed(void) const {
|
||||
return sockfd == INVALID_SOCKET;
|
||||
}
|
||||
/*! \brief close the socket */
|
||||
inline void Close(void) {
|
||||
if (sockfd != INVALID_SOCKET) {
|
||||
@ -177,7 +197,6 @@ class Socket {
|
||||
Error("Socket::Close double close the socket or close without create");
|
||||
}
|
||||
}
|
||||
|
||||
// report a socket error
|
||||
inline static void Error(const char *msg) {
|
||||
int errsv = errno;
|
||||
@ -267,9 +286,8 @@ class TCPSocket : public Socket{
|
||||
*/
|
||||
inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
|
||||
char *buf = reinterpret_cast<char*>(buf_);
|
||||
if (len == 0) return 0;
|
||||
return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief perform a blocking write that will attempt to send all data out
|
||||
* can still return smaller than request when error occurs
|
||||
@ -319,14 +337,17 @@ class TCPSocket : public Socket{
|
||||
struct SelectHelper {
|
||||
public:
|
||||
SelectHelper(void) {
|
||||
this->Clear();
|
||||
FD_ZERO(&read_set);
|
||||
FD_ZERO(&write_set);
|
||||
FD_ZERO(&except_set);
|
||||
maxfd = 0;
|
||||
}
|
||||
/*!
|
||||
* \brief add file descriptor to watch for read
|
||||
* \param fd file descriptor to be watched
|
||||
*/
|
||||
inline void WatchRead(SOCKET fd) {
|
||||
read_fds.push_back(fd);
|
||||
FD_SET(fd, &read_set);
|
||||
if (fd > maxfd) maxfd = fd;
|
||||
}
|
||||
/*!
|
||||
@ -334,7 +355,7 @@ struct SelectHelper {
|
||||
* \param fd file descriptor to be watched
|
||||
*/
|
||||
inline void WatchWrite(SOCKET fd) {
|
||||
write_fds.push_back(fd);
|
||||
FD_SET(fd, &write_set);
|
||||
if (fd > maxfd) maxfd = fd;
|
||||
}
|
||||
/*!
|
||||
@ -342,7 +363,7 @@ struct SelectHelper {
|
||||
* \param fd file descriptor to be watched
|
||||
*/
|
||||
inline void WatchException(SOCKET fd) {
|
||||
except_fds.push_back(fd);
|
||||
FD_SET(fd, &except_set);
|
||||
if (fd > maxfd) maxfd = fd;
|
||||
}
|
||||
/*!
|
||||
@ -367,51 +388,49 @@ struct SelectHelper {
|
||||
return FD_ISSET(fd, &except_set) != 0;
|
||||
}
|
||||
/*!
|
||||
* \brief clear all the monitored descriptors
|
||||
* \brief wait for exception event on a single descriptor
|
||||
* \param fd the file descriptor to wait the event for
|
||||
* \param timeout the timeout in milliseconds, can be 0, which means wait until the event happens
|
||||
* \return 1 if success, 0 if timeout, and -1 if error occurs
|
||||
*/
|
||||
inline void Clear(void) {
|
||||
read_fds.clear();
|
||||
write_fds.clear();
|
||||
except_fds.clear();
|
||||
maxfd = 0;
|
||||
}
|
||||
inline static int WaitExcept(SOCKET fd, long timeout = 0) {
  // build an exception set that holds only the requested descriptor
  fd_set except_only;
  FD_ZERO(&except_only);
  FD_SET(fd, &except_only);
  // no read or write sets are passed, only the exception set
  const int nfds = static_cast<int>(fd + 1);
  return Select_(nfds, NULL, NULL, &except_only, timeout);
}
|
||||
/*!
|
||||
* \brief perform select on the set defined
|
||||
* \param select_read whether to watch for read event
|
||||
* \param select_write whether to watch for write event
|
||||
* \param select_except whether to watch for exception event
|
||||
* \param timeout specify timeout in milliseconds (ms); if it equals 0, select will always block
|
||||
* \return number of active descriptors selected,
|
||||
* return -1 if error occurs
|
||||
*/
|
||||
inline int Select(long timeout = 0) {
|
||||
FD_ZERO(&read_set);
|
||||
FD_ZERO(&write_set);
|
||||
FD_ZERO(&except_set);
|
||||
for (size_t i = 0; i < read_fds.size(); ++i) {
|
||||
FD_SET(read_fds[i], &read_set);
|
||||
}
|
||||
for (size_t i = 0; i < write_fds.size(); ++i) {
|
||||
FD_SET(write_fds[i], &write_set);
|
||||
}
|
||||
for (size_t i = 0; i < except_fds.size(); ++i) {
|
||||
FD_SET(except_fds[i], &except_set);
|
||||
}
|
||||
int ret;
|
||||
if (timeout == 0) {
|
||||
ret = select(static_cast<int>(maxfd + 1), &read_set,
|
||||
&write_set, &except_set, NULL);
|
||||
} else {
|
||||
timeval tm;
|
||||
tm.tv_usec = (timeout % 1000) * 1000;
|
||||
tm.tv_sec = timeout / 1000;
|
||||
ret = select(static_cast<int>(maxfd + 1), &read_set,
|
||||
&write_set, &except_set, &tm);
|
||||
int ret = Select_(static_cast<int>(maxfd + 1),
|
||||
&read_set, &write_set, &except_set, timeout);
|
||||
if (ret == -1) {
|
||||
Socket::Error("Select");
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
private:
|
||||
inline static int Select_(int maxfd, fd_set *rfds, fd_set *wfds, fd_set *efds, long timeout) {
  // a zero timeout passes no timeval, so select waits without a time limit
  if (timeout == 0) return select(maxfd, rfds, wfds, efds, NULL);
  // split the millisecond count into whole seconds plus leftover microseconds
  timeval limit;
  limit.tv_sec = timeout / 1000;
  limit.tv_usec = (timeout % 1000) * 1000;
  return select(maxfd, rfds, wfds, efds, &limit);
}
|
||||
|
||||
SOCKET maxfd;
|
||||
fd_set read_set, write_set, except_set;
|
||||
std::vector<SOCKET> read_fds, write_fds, except_fds;
|
||||
};
|
||||
}
|
||||
#endif
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user