diff --git a/src/engine_robust.cc b/src/engine_robust.cc index b969bc9f6..00efd7447 100644 --- a/src/engine_robust.cc +++ b/src/engine_robust.cc @@ -17,18 +17,14 @@ void AllReduceRobust::AllReduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, ReduceFunction reducer) { - utils::LogPrintf("[%d] call AllReduce", rank); - TryResetLinks(); - utils::LogPrintf("[%d] start work", rank); while (true) { ReturnType ret = TryAllReduce(sendrecvbuf_, type_nbytes, count, reducer); if (ret == kSuccess) return; if (ret == kSockError) { utils::Error("error occur during all reduce\n"); } - utils::LogPrintf("[%d] receive except signal, start reset link", rank); + utils::LogPrintf("[%d] receive except signal, start reset link\n", rank); TryResetLinks(); - //utils::Check(TryResetLinks() == kSuccess, "error when reset links"); } // TODO } @@ -70,13 +66,13 @@ void AllReduceRobust::CheckPoint(const utils::ISerializable &model) { * and some link recovery proceduer is needed */ AllReduceRobust::ReturnType AllReduceRobust::TryResetLinks(void) { - utils::LogPrintf("[%d] TryResetLinks, start\n", rank); // number of links const int nlink = static_cast(links.size()); for (int i = 0; i < nlink; ++i) { links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size); links[i].ResetSize(); } + // read and discard data from all channels until pass mark while (true) { for (int i = 0; i < nlink; ++i) { @@ -92,20 +88,6 @@ AllReduceRobust::ReturnType AllReduceRobust::TryResetLinks(void) { ssize_t len = links[i].sock.Send(&sig, sizeof(sig)); if (len == sizeof(sig)) links[i].size_write = 2; } - if (links[i].size_read == 0) { - int atmark = links[i].sock.AtMark(); - if (atmark < 0) { - utils::Assert(links[i].sock.BadSocket(), "must already gone bad"); - } else if (atmark > 0) { - links[i].size_read = 1; - } else { - printf("buffer_size=%lu\n", links[i].buffer_size); - // no at mark, read and discard data - ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size); - // zero length, remote closed the connection, close socket - if (len == 0) links[i].sock.Close(); - } - } } utils::SelectHelper rsel; bool finished = true; @@ -113,15 +95,44 @@ AllReduceRobust::ReturnType AllReduceRobust::TryResetLinks(void) { if (links[i].size_write != 2 && !links[i].sock.BadSocket()) { rsel.WatchWrite(links[i].sock); finished = false; } - if (links[i].size_read == 0 && !links[i].sock.BadSocket()) { - rsel.WatchRead(links[i].sock); finished = false; - } } if (finished) break; // wait to read from the channels to discard data rsel.Select(); } - utils::LogPrintf("[%d] Finish discard data\n", rank); + for (int i = 0; i < nlink; ++i) { + if (!links[i].sock.BadSocket()) { + utils::SelectHelper::WaitExcept(links[i].sock); + } + } + while (true) { + for (int i = 0; i < nlink; ++i) { + if (links[i].size_read == 0) { + int atmark = links[i].sock.AtMark(); + if (atmark < 0) { + utils::Assert(links[i].sock.BadSocket(), "must already gone bad"); + } else if (atmark > 0) { + links[i].size_read = 1; + } else { + // no at mark, read and discard data + ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size); + if (links[i].sock.AtMark()) links[i].size_read = 1; + // zero length, remote closed the connection, close socket + if (len == 0) links[i].sock.Close(); + } + } + } + utils::SelectHelper rsel; + bool finished = true; + for (int i = 0; i < nlink; ++i) { + if (links[i].size_read == 0 && !links[i].sock.BadSocket()) { + rsel.WatchRead(links[i].sock); finished = false; + } + } + if (finished) break; + rsel.Select(); + } + // start synchronization, use blocking I/O to avoid select for (int i = 0; i < nlink; ++i) { if (!links[i].sock.BadSocket()) { @@ -132,7 +143,7 @@ AllReduceRobust::ReturnType AllReduceRobust::TryResetLinks(void) { links[i].sock.Close(); continue; } else if (len > 0) { utils::Assert(oob_mark == kResetMark, "wrong oob msg"); - utils::Assert(!links[i].sock.AtMark(), "should already read past mark"); + utils::Assert(links[i].sock.AtMark() != 1, "should already read past mark"); } else { utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); } @@ -147,7 +158,6 @@ AllReduceRobust::ReturnType AllReduceRobust::TryResetLinks(void) { } } } - utils::LogPrintf("[%d] GGet all Acks\n", rank); // wait all ack for (int i = 0; i < nlink; ++i) { if (!links[i].sock.BadSocket()) { @@ -167,7 +177,6 @@ AllReduceRobust::ReturnType AllReduceRobust::TryResetLinks(void) { for (int i = 0; i < nlink; ++i) { if (links[i].sock.BadSocket()) return kSockError; } - utils::LogPrintf("[%d] TryResetLinks,!! return\n", rank); return kSuccess; } diff --git a/src/socket.h b/src/socket.h index a5238a6c0..8f6d969e6 100644 --- a/src/socket.h +++ b/src/socket.h @@ -177,7 +177,7 @@ class Socket { inline bool BadSocket(void) const { if (IsClosed()) return true; int err = GetSockError(); - if (err == EBADF || err == EINTR) return true; + if (err == EBADF || err == EINTR) return true; return false; } /*! \brief check if socket is already closed */ @@ -250,7 +250,7 @@ class TCPSocket : public Socket{ int atmark; #ifdef _WIN32 if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1; -#else +#else if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1; #endif return atmark; @@ -418,6 +418,7 @@ struct SelectHelper { private: inline static int Select_(int maxfd, fd_set *rfds, fd_set *wfds, fd_set *efds, long timeout) { + utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE"); if (timeout == 0) { return select(maxfd, rfds, wfds, efds, NULL); } else { diff --git a/src/utils.h b/src/utils.h index 81bba7dfd..a371d6059 100644 --- a/src/utils.h +++ b/src/utils.h @@ -78,6 +78,7 @@ inline void HandlePrint(const char *msg) { } inline void HandleLogPrint(const char *msg) { fprintf(stderr, "%s", msg); + fflush(stderr); } #else #ifndef ALLREDUCE_STRICT_CXX98_