diff --git a/src/engine_robust.cc b/src/engine_robust.cc index 9db11bc96..ab33b0f0a 100644 --- a/src/engine_robust.cc +++ b/src/engine_robust.cc @@ -39,7 +39,6 @@ void AllReduceRobust::AllReduce(void *sendrecvbuf_, size_t count, ReduceFunction reducer) { bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter); - //utils::LogPrintf("[%d] AllReduce recovered=%d\n", rank, recovered); // now we are free to remove the last result, if any if (resbuf.LastSeqNo() != -1 && (resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) { @@ -442,14 +441,10 @@ AllReduceRobust::TryRecoverData(RecoverType role, // do not need to provide data or receive data, directly exit if (!req_data) return kSuccess; } - utils::LogPrintf("[%d] !!Need to pass data\n", rank); utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active"); for (int i = 0; i < nlink; ++i) { links[i].ResetSize(); } - if (role == kPassData) { - links[recv_link].InitBuffer(1, size, reduce_buffer_size); - } while (true) { bool finished = true; utils::SelectHelper selecter; @@ -457,9 +452,12 @@ AllReduceRobust::TryRecoverData(RecoverType role, if (i == recv_link && links[i].size_read != size) { selecter.WatchRead(links[i].sock); finished = false; - } + } if (req_in[i] && links[i].size_write != size) { - selecter.WatchWrite(links[i].sock); + if (role == kHaveData || + (role == kPassData && links[recv_link].size_read != links[i].size_write)) { + selecter.WatchWrite(links[i].sock); + } finished = false; } selecter.WatchException(links[i].sock); @@ -496,12 +494,12 @@ AllReduceRobust::TryRecoverData(RecoverType role, utils::Assert(min_write <= links[pid].size_read, "boundary check"); if (!links[pid].ReadToRingBuffer(min_write)) return kSockError; } - for (int i = 0; i < nlink; ++i) { - if (req_in[i] && selecter.CheckWrite(links[i].sock)) { + for (int i = 0; i < nlink; ++i) { + if (req_in[i] && selecter.CheckWrite(links[i].sock) && links[pid].size_read != links[i].size_write) { size_t start = links[i].size_write % buffer_size; // send out data from ring buffer - size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write); - ssize_t len = links[pid].sock.Send(links[pid].buffer_head + start, nwrite); + size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write); + ssize_t len = links[i].sock.Send(links[pid].buffer_head + start, nwrite); if (len != -1) { links[i].size_write += len; } else { @@ -559,7 +557,6 @@ AllReduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re } else { role = kRequestData; } - utils::LogPrintf("[%d] role=%d\n", rank, role); int recv_link; std::vector req_in; ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in); @@ -590,7 +587,6 @@ bool AllReduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) { } // request ActionSummary req(flag, seqno); - utils::LogPrintf("[%d] propose flag=%d, seq=%d\n", rank, flag, seqno); while (true) { // action ActionSummary act = req; diff --git a/test/test_recover.cpp b/test/test_recover.cpp index 9267cdca5..215177f20 100644 --- a/test/test_recover.cpp +++ b/test/test_recover.cpp @@ -19,6 +19,7 @@ inline void TestMax(test::Mock &mock, size_t n, int ntrial) { ndata[i] = (i * (rank+1)) % 111; } mock.AllReduce(&ndata[0], ndata.size()); + if (ntrial == 0 && rank == 15) throw MockException(); for (size_t i = 0; i < ndata.size(); ++i) { float rmax = (i * 1) % 111; for (int r = 0; r < nproc; ++r) {