diff --git a/src/engine_robust.cpp b/src/engine_robust.cpp index 9ed2a31fb..3382d189d 100644 --- a/src/engine_robust.cpp +++ b/src/engine_robust.cpp @@ -1,6 +1,6 @@ /*! * \file engine_robust.cpp - * \brief Robust implementation of AllReduce + * \brief Robust implementation of AllReduce * using TCP non-block socket and tree-shape reduction. * * This implementation considers the failure of nodes @@ -35,8 +35,10 @@ class AllReduceManager : public IEngine { // constant one byte out of band message to indicate error happening // and mark for channel cleanup const static char kOOBReset = 95; + // and mark for channel cleanup, after OOB signal + const static char kResetMark = 97; // and mark for channel cleanup - const static char kOOBResetAck = 97; + const static char kResetAck = 97; AllReduceManager(void) { master_uri = "NULL"; @@ -173,7 +175,6 @@ class AllReduceManager : public IEngine { size_t count, ReduceFunction reducer) { while (true) { - if (rank == rand() % 3) TryResetLinks(); ReturnType ret = TryAllReduce(sendrecvbuf_, type_nbytes, count, reducer); if (ret == kSuccess) return; if (ret == kSockError) { @@ -280,119 +281,95 @@ class AllReduceManager : public IEngine { for (int i = 0; i < nlink; ++i) { links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size); links[i].ResetSize(); + links[i].except = false; } - printf("[%d] start to reset link\n", rank); + // read and discard data from all channels until pass mark while (true) { - printf("[%d] loop\n", rank); - bool finished = true; for (int i = 0; i < nlink; ++i) { if (links[i].sock.BadSocket()) continue; if (links[i].size_write == 0) { char sig = kOOBReset; ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB); // error will be filtered in next loop - if (len != -1) { - links[i].size_write += len; - printf("[%d] send OOB success\n", rank); + if (len == sizeof(sig)) links[i].size_write = 1; + } + if (links[i].size_write == 1) { + char sig = kResetMark; + ssize_t len = links[i].sock.Send(&sig, sizeof(sig)); + if (len == sizeof(sig)) links[i].size_write = 2; + } + if (links[i].size_read == 0) { + int atmark = links[i].sock.AtMark(); + if (atmark < 0) { + utils::Assert(links[i].sock.BadSocket(), "must already gone bad"); + } else if (atmark > 0) { + links[i].size_read = 1; + } else { + // no at mark, read and discard data + ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size); + // zero length, remote closed the connection, close socket + if (len == 0) links[i].sock.Close(); } } - // need to send OOB to every other link - if (links[i].size_write == 0) finished = false; } - if (finished) break; - } - printf("[%d] finish send all OOB\n", rank); - // wait for incoming except from all links - for (int i = 0; i < nlink; ++ i) { - if (links[i].sock.BadSocket()) continue; - printf("[%d] wait except\n", rank); - if (utils::SelectHelper::WaitExcept(links[i].sock) == -1) { - utils::Socket::Error("select"); - } - printf("[%d] finish wait except\n", rank); - } - printf("[%d] start to discard link\n", rank); - // read and discard data from all channels until pass mark - while (true) { utils::SelectHelper rsel; bool finished = true; for (int i = 0; i < nlink; ++i) { - if (links[i].sock.BadSocket()) continue; - if (links[i].size_read == 0) { - int atmark = links[i].sock.AtMark(); - if (atmark < 0) return kSockError; - if (atmark == 1) { - char oob_msg; - ssize_t len = links[i].sock.Recv(&oob_msg, sizeof(oob_msg), MSG_OOB); - if (len == -1 && errno != EAGAIN && errno != EWOULDBLOCK) { - finished = false; continue; - } - utils::Assert(oob_msg == kOOBReset, "wrong oob msg"); - links[i].size_read = 1; - } else { - finished = false; - rsel.WatchRead(links[i].sock); - } + if (links[i].size_write != 2 && !links[i].sock.BadSocket()) { + rsel.WatchWrite(links[i].sock); finished = false; + } + if (links[i].size_read == 0 && !links[i].sock.BadSocket()) { + rsel.WatchRead(links[i].sock); finished = false; } } if (finished) break; // wait to read from the channels to discard data rsel.Select(); - printf("[%d] select finish read from\n", rank); - for (int i = 0; i < nlink; ++i) { - if (links[i].sock.BadSocket()) continue; - if (rsel.CheckRead(links[i].sock)) { - ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size); - // zero length, remote closed the connection, close socket - if (len == 0) { - links[i].sock.Close(); - } else if (len == -1) { - // when error happens here, oob_clear will remember - if (errno == EAGAIN && errno == EWOULDBLOCK) printf("would block\n"); - } else { - printf("[%d] discard %ld bytes\n", rank, len); - } - } - } } - printf("[%d] discard all success\n", rank); - // start synchronization step + // start synchronization, use blocking I/O to avoid select for (int i = 0; i < nlink; ++i) { - links[i].ResetSize(); - } - while (true) { - // selecter for TryResetLinks - utils::SelectHelper rsel; - for (int i = 0; i < nlink; ++i) { - if (links[i].sock.BadSocket()) continue; - if (links[i].size_read == 0) rsel.WatchRead(links[i].sock); - if (links[i].size_write == 0) rsel.WatchWrite(links[i].sock); - } - printf("[%d] before select\n", rank); - rsel.Select(); - printf("[%d] after select\n", rank); - bool finished = true; - for (int i = 0; i < nlink; ++i) { - if (links[i].sock.BadSocket()) continue; - if (links[i].size_read == 0 && rsel.CheckRead(links[i].sock)) { - char ack; - links[i].ReadToArray(&ack, sizeof(ack)); - if (links[i].size_read != 0) { - utils::Assert(ack == kOOBResetAck, "expect ack message"); + if (!links[i].sock.BadSocket()) { + char oob_mark; + links[i].sock.SetNonBlock(false); + ssize_t len = links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL); + if (len == 0) { + links[i].sock.Close(); continue; + } else if (len > 0) { + utils::Assert(oob_mark == kResetMark, "wrong oob msg"); + utils::Assert(!links[i].sock.AtMark(), "should already read past mark"); + } else { + utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); + } + // send out ack + char ack = kResetAck; + while (true) { + len = links[i].sock.Send(&ack, sizeof(ack)); + if (len == sizeof(ack)) break; + if (len == -1) { + if (errno != EAGAIN && errno != EWOULDBLOCK) break; } } - if (links[i].size_write == 0 && rsel.CheckWrite(links[i].sock)) { - char ack = kOOBResetAck; - links[i].WriteFromArray(&ack, sizeof(ack)); - } - if (links[i].size_read == 0 || links[i].size_write == 0) finished = false; } - if (finished) break; } - printf("[%d] after the read write data success\n", rank); + // wait all ack + for (int i = 0; i < nlink; ++i) { + if (!links[i].sock.BadSocket()) { + char ack; + ssize_t len = links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL); + if (len == 0) { + links[i].sock.Close(); continue; + } else if (len > 0) { + utils::Assert(ack == kResetAck, "wrong Ack MSG"); + } else { + utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); + } + // set back to nonblock mode + links[i].sock.SetNonBlock(true); + } + } for (int i = 0; i < nlink; ++i) { if (links[i].sock.BadSocket()) return kSockError; - } + } return kSuccess; } // Run AllReduce, return if success @@ -540,9 +517,12 @@ class AllReduceManager : public IEngine { // pointer to buffer head char *buffer_head; // buffer size, in bytes - size_t buffer_size; + size_t buffer_size; + // exception + bool except; // constructor LinkRecord(void) {} + // initialize buffer inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) { size_t n = (type_nbytes * count + 7)/ 8; @@ -587,7 +567,7 @@ class AllReduceManager : public IEngine { * \return true if it is an successful read, false if there is some error happens, check errno */ inline bool ReadToArray(void *recvbuf_, size_t max_size) { - if (max_size == size_read ) return true; + if (max_size == size_read) return true; char *p = static_cast(recvbuf_); ssize_t len = sock.Recv(p + size_read, max_size - size_read); // length equals 0, remote disconnected diff --git a/test/test_allreduce.cpp b/test/test_allreduce.cpp index 9afdc6d03..40c85ea0b 100644 --- a/test/test_allreduce.cpp +++ b/test/test_allreduce.cpp @@ -76,9 +76,9 @@ int main(int argc, char *argv[]) { printf("[%d] start at %s\n", rank, name.c_str()); TestMax(mock, n); - printf("[%d] TestMax pass\n", rank); + printf("[%d] !!!TestMax pass\n", rank); TestSum(mock, n); - printf("[%d] TestSum pass\n", rank); + printf("[%d] !!!TestSum pass\n", rank); sync::Finalize(); printf("[%d] all check pass\n", rank); return 0;