diff --git a/src/engine_robust.cpp b/src/engine_robust.cpp index ad2ff34da..9ed2a31fb 100644 --- a/src/engine_robust.cpp +++ b/src/engine_robust.cpp @@ -35,6 +35,8 @@ class AllReduceManager : public IEngine { // constant one byte out of band message to indicate error happening // and mark for channel cleanup const static char kOOBReset = 95; + // one byte ack each link exchanges once the reset mark has been consumed + const static char kOOBResetAck = 97; AllReduceManager(void) { master_uri = "NULL"; @@ -148,13 +150,9 @@ class AllReduceManager : public IEngine { // close listening sockets sock_listen.Close(); // setup selecter - selecter.Clear(); for (size_t i = 0; i < links.size(); ++i) { // set the socket to non-blocking mode links[i].sock.SetNonBlock(true); - selecter.WatchRead(links[i].sock); - selecter.WatchWrite(links[i].sock); - selecter.WatchException(links[i].sock); } // done } @@ -211,7 +209,13 @@ class AllReduceManager : public IEngine { // while we have not passed the messages out while(true) { - selecter.Select(); + // select helper; NOTE(review): no selecter.Select() is visible after this loop in the hunk, confirm + utils::SelectHelper selecter; + for (size_t i = 0; i < links.size(); ++i) { + selecter.WatchRead(links[i].sock); + selecter.WatchWrite(links[i].sock); + selecter.WatchException(links[i].sock); + } if (in_link == -2) { // probe in-link for (int i = 0; i < nlink; ++i) { @@ -277,71 +281,118 @@ class AllReduceManager : public IEngine { links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size); links[i].ResetSize(); } - printf("[%d] start to reset link\n", rank); + printf("[%d] start to reset link\n", rank); while (true) { - if (selecter.Select() == -1) { - if (errno == EBADF || errno == EINTR) return kSockError; - utils::Socket::Error("select"); - } + printf("[%d] loop\n", rank); bool finished = true; for (int i = 0; i < nlink; ++i) { - if (selecter.CheckWrite(links[i].sock)) { - if (links[i].size_write == 0) { - char sig = kOOBReset; - ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB); - if (len != -1) { - links[i].size_write += len; - } else { - if
(errno != EAGAIN && errno != EWOULDBLOCK) return kSockError; - } + if (links[i].sock.BadSocket()) continue; + if (links[i].size_write == 0) { + char sig = kOOBReset; + ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB); + // error will be filtered in next loop + if (len != -1) { + links[i].size_write += len; + printf("[%d] send OOB success\n", rank); } } // need to send OOB to every other link if (links[i].size_write == 0) finished = false; - // need to receive OOB from every link, or already cleanup some link - if (!links[i].oob_clear && !selecter.CheckExcept(links[i].sock)) finished = false; } if (finished) break; } + printf("[%d] finish send all OOB\n", rank); + // wait for incoming except from all links + for (int i = 0; i < nlink; ++ i) { + if (links[i].sock.BadSocket()) continue; + printf("[%d] wait except\n", rank); + if (utils::SelectHelper::WaitExcept(links[i].sock) == -1) { + utils::Socket::Error("select"); + } + printf("[%d] finish wait except\n", rank); + } printf("[%d] start to discard link\n", rank); // read and discard data from all channels until pass mark while (true) { - if (selecter.Select() == -1) { - if (errno == EBADF || errno == EINTR) return kSockError; - utils::Socket::Error("select"); - } + utils::SelectHelper rsel; bool finished = true; for (int i = 0; i < nlink; ++i) { - if (selecter.CheckExcept(links[i].sock)) { + if (links[i].sock.BadSocket()) continue; + if (links[i].size_read == 0) { int atmark = links[i].sock.AtMark(); if (atmark < 0) return kSockError; if (atmark == 1) { char oob_msg; ssize_t len = links[i].sock.Recv(&oob_msg, sizeof(oob_msg), MSG_OOB); - if (len == -1 && errno != EAGAIN && errno != EWOULDBLOCK) return kSockError; - utils::Assert(oob_msg == kOOBReset, "wrong oob msg"); - } else { - ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size); - if (len == -1) { - // when error happens here, oob_clear will remember - if (errno == EAGAIN && errno == EWOULDBLOCK) printf("would block\n"); - }
else { - printf("[%d] discard %ld bytes\n", rank, len); + if (len == -1 && errno != EAGAIN && errno != EWOULDBLOCK) { + finished = false; continue; } - // the existing exception already cleared by this loop - if (len == -1 && errno != EAGAIN && errno != EWOULDBLOCK) return kSockError; + utils::Assert(oob_msg == kOOBReset, "wrong oob msg"); + links[i].size_read = 1; + } else { + finished = false; + rsel.WatchRead(links[i].sock); } - finished = false; - } else { - links[i].oob_clear = true; } } if (finished) break; + // wait to read from the channels to discard data + rsel.Select(); + printf("[%d] select finish read from\n", rank); + for (int i = 0; i < nlink; ++i) { + if (links[i].sock.BadSocket()) continue; + if (rsel.CheckRead(links[i].sock)) { + ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size); + // zero length, remote closed the connection, close socket + if (len == 0) { + links[i].sock.Close(); + } else if (len == -1) { + // NOTE(review): the && below only holds where EAGAIN == EWOULDBLOCK, || is likely intended + if (errno == EAGAIN && errno == EWOULDBLOCK) printf("would block\n"); + } else { + printf("[%d] discard %ld bytes\n", rank, len); + } + } + } } - // mark oob_clear mark as false + printf("[%d] discard all success\n", rank); + // start synchronization step for (int i = 0; i < nlink; ++i) { - links[i].oob_clear = false; + links[i].ResetSize(); } + while (true) { + // selecter for TryResetLinks + utils::SelectHelper rsel; + for (int i = 0; i < nlink; ++i) { + if (links[i].sock.BadSocket()) continue; + if (links[i].size_read == 0) rsel.WatchRead(links[i].sock); + if (links[i].size_write == 0) rsel.WatchWrite(links[i].sock); + } + printf("[%d] before select\n", rank); + rsel.Select(); + printf("[%d] after select\n", rank); + bool finished = true; + for (int i = 0; i < nlink; ++i) { + if (links[i].sock.BadSocket()) continue; + if (links[i].size_read == 0 && rsel.CheckRead(links[i].sock)) { + char ack; + links[i].ReadToArray(&ack, sizeof(ack)); + if (links[i].size_read
!= 0) { + utils::Assert(ack == kOOBResetAck, "expect ack message"); + } + } + if (links[i].size_write == 0 && rsel.CheckWrite(links[i].sock)) { + char ack = kOOBResetAck; + links[i].WriteFromArray(&ack, sizeof(ack)); + } + if (links[i].size_read == 0 || links[i].size_write == 0) finished = false; + } + if (finished) break; + } + printf("[%d] after the read write data success\n", rank); + for (int i = 0; i < nlink; ++i) { + if (links[i].sock.BadSocket()) return kSockError; + } return kSuccess; } // Run AllReduce, return if success @@ -376,10 +427,15 @@ class AllReduceManager : public IEngine { // while we have not passed the messages out while (true) { - if (selecter.Select() == -1) { - if (errno == EBADF || errno == EINTR) return kSockError; - utils::Socket::Error("select"); + // select helper + utils::SelectHelper selecter; + for (size_t i = 0; i < links.size(); ++i) { + selecter.WatchRead(links[i].sock); + selecter.WatchWrite(links[i].sock); + selecter.WatchException(links[i].sock); } + // select must return + selecter.Select(); // exception handling for (int i = 0; i < nlink; ++i) { // recive OOB message from some link @@ -437,9 +493,12 @@ class AllReduceManager : public IEngine { } } // read data from parent - if (selecter.CheckRead(links[parent_index].sock)) { + if (selecter.CheckRead(links[parent_index].sock) && total_size > size_down_in) { ssize_t len = links[parent_index].sock.
Recv(sendrecvbuf + size_down_in, total_size - size_down_in); + if (len == 0) { + links[parent_index].sock.Close(); return kSockError; + } if (len != -1) { size_down_in += static_cast(len); utils::Assert(size_down_in <= size_up_out, "AllReduce: boundary error"); @@ -482,10 +541,8 @@ class AllReduceManager : public IEngine { char *buffer_head; // buffer size, in bytes size_t buffer_size; - // state used by TryResetLinks, whether a link is already cleaned from OOB mark - bool oob_clear; // constructor - LinkRecord(void) : oob_clear(false) {} + LinkRecord(void) {} // initialize buffer inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) { size_t n = (type_nbytes * count + 7)/ 8; @@ -511,8 +568,13 @@ class AllReduceManager : public IEngine { size_t ngap = size_read - protect_start; utils::Assert(ngap <= buffer_size, "AllReduce: boundary check"); size_t offset = size_read % buffer_size; - size_t nmax = std::min(buffer_size - ngap, buffer_size - offset); + size_t nmax = std::min(buffer_size - ngap, buffer_size - offset); + if (nmax == 0) return true; ssize_t len = sock.Recv(buffer_head + offset, nmax); + // length equals 0, remote disconnected + if (len == 0) { + sock.Close(); return false; + } if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK; size_read += static_cast(len); return true; @@ -525,8 +587,13 @@ class AllReduceManager : public IEngine { * \return true if it is an successful read, false if there is some error happens, check errno */ inline bool ReadToArray(void *recvbuf_, size_t max_size) { + if (max_size == size_read ) return true; char *p = static_cast(recvbuf_); ssize_t len = sock.Recv(p + size_read, max_size - size_read); + // length equals 0, remote disconnected + if (len == 0) { + sock.Close(); return false; + } if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK; size_read += static_cast(len); return true; @@ -613,8 +680,6 @@ class AllReduceManager : public IEngine { int parent_index; // sockets of
all links std::vector links; - // select helper - utils::SelectHelper selecter; //----- meta information----- // uri of current host, to be set by Init std::string host_uri; diff --git a/src/socket.h b/src/socket.h index 9cbf1bcea..b9754f87d 100644 --- a/src/socket.h +++ b/src/socket.h @@ -164,6 +164,26 @@ class Socket { } return -1; } + /*! \brief get last error code if any */ + inline int GetSockError(void) const { + int error = 0; + socklen_t len = sizeof(error); + if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &error, &len) != 0) { + Error("GetSockError"); + } + return error; + } + /*! \brief true if the socket is closed, or its pending SO_ERROR is EBADF or EINTR */ + inline bool BadSocket(void) const { + if (IsClosed()) return true; + int err = GetSockError(); + if (err == EBADF || err == EINTR) return true; + return false; + } + /*! \brief check if socket is already closed */ + inline bool IsClosed(void) const { + return sockfd == INVALID_SOCKET; + } /*! \brief close the socket */ inline void Close(void) { if (sockfd != INVALID_SOCKET) { @@ -177,7 +197,6 @@ class Socket { Error("Socket::Close double close the socket or close without create"); } } - // report an socket error inline static void Error(const char *msg) { int errsv = errno; @@ -267,9 +286,8 @@ class TCPSocket : public Socket{ */ inline ssize_t Recv(void *buf_, size_t len, int flags = 0) { char *buf = reinterpret_cast(buf_); - if (len == 0) return 0; return recv(sockfd, buf, static_cast(len), flags); - } + } /*! * \brief peform block write that will attempt to send all data out * can still return smaller than request when error occurs @@ -319,14 +337,17 @@ class TCPSocket : public Socket{ struct SelectHelper { public: SelectHelper(void) { - this->Clear(); + FD_ZERO(&read_set); + FD_ZERO(&write_set); + FD_ZERO(&except_set); + maxfd = 0; } /*!
* \brief add file descriptor to watch for read * \param fd file descriptor to be watched */ inline void WatchRead(SOCKET fd) { - read_fds.push_back(fd); + FD_SET(fd, &read_set); if (fd > maxfd) maxfd = fd; } /*! @@ -334,7 +355,7 @@ struct SelectHelper { * \param fd file descriptor to be watched */ inline void WatchWrite(SOCKET fd) { - write_fds.push_back(fd); + FD_SET(fd, &write_set); if (fd > maxfd) maxfd = fd; } /*! @@ -342,7 +363,7 @@ struct SelectHelper { * \param fd file descriptor to be watched */ inline void WatchException(SOCKET fd) { - except_fds.push_back(fd); + FD_SET(fd, &except_set); if (fd > maxfd) maxfd = fd; } /*! @@ -367,51 +388,49 @@ struct SelectHelper { return FD_ISSET(fd, &except_set) != 0; } /*! - * \brief clear all the monitored descriptors + * \brief wait for exception event on a single descriptor + * \param fd the file descriptor to wait the event for + * \param timeout timeout in milliseconds, 0 means block until the event happens + * \return 1 if success, 0 if timeout, and -1 if error occurs */ - inline void Clear(void) { - read_fds.clear(); - write_fds.clear(); - except_fds.clear(); - maxfd = 0; - } + inline static int WaitExcept(SOCKET fd, long timeout = 0) { + fd_set wait_set; + FD_ZERO(&wait_set); + FD_SET(fd, &wait_set); + return Select_(static_cast(fd + 1), NULL, NULL, &wait_set, timeout); + } /*!
* \brief peform select on the set defined + * the sets filled by WatchRead/WatchWrite/WatchException are passed to select as is, + * select overwrites them with the ready descriptors, so register again before reuse; + * note: Select_ converts timeout as milliseconds (tv_sec = timeout / 1000) * \param timeout specify timeout in micro-seconds(ms) if equals 0, means select will always block * \return number of active descriptors selected, * return -1 if error occurs */ inline int Select(long timeout = 0) { - FD_ZERO(&read_set); - FD_ZERO(&write_set); - FD_ZERO(&except_set); - for (size_t i = 0; i < read_fds.size(); ++i) { - FD_SET(read_fds[i], &read_set); - } - for (size_t i = 0; i < write_fds.size(); ++i) { - FD_SET(write_fds[i], &write_set); - } - for (size_t i = 0; i < except_fds.size(); ++i) { - FD_SET(except_fds[i], &except_set); - } - int ret; - if (timeout == 0) { - ret = select(static_cast(maxfd + 1), &read_set, - &write_set, &except_set, NULL); - } else { - timeval tm; - tm.tv_usec = (timeout % 1000) * 1000; - tm.tv_sec = timeout / 1000; - ret = select(static_cast(maxfd + 1), &read_set, - &write_set, &except_set, &tm); + int ret = Select_(static_cast(maxfd + 1), + &read_set, &write_set, &except_set, timeout); + if (ret == -1) { + Socket::Error("Select"); } return ret; } private: + inline static int Select_(int maxfd, fd_set *rfds, fd_set *wfds, fd_set *efds, long timeout) { + if (timeout == 0) { + return select(maxfd, rfds, wfds, efds, NULL); + } else { + timeval tm; + tm.tv_usec = (timeout % 1000) * 1000; + tm.tv_sec = timeout / 1000; + return select(maxfd, rfds, wfds, efds, &tm); + } + } + SOCKET maxfd; fd_set read_set, write_set, except_set; - std::vector read_fds, write_fds, except_fds; }; } #endif