diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index b905c2217..1736d8f6d 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -366,7 +366,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_, selecter.WatchException(links[i].sock); finished = false; } - if (size_up_out != total_size) { + if (size_up_out != total_size && size_up_out < size_up_reduce) { selecter.WatchWrite(links[i].sock); } } else { @@ -374,8 +374,10 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_, selecter.WatchRead(links[i].sock); } // size_write <= size_read - if (links[i].size_write != total_size) { - selecter.WatchWrite(links[i].sock); + if (links[i].size_write != total_size){ + if (links[i].size_write < size_down_in) { + selecter.WatchWrite(links[i].sock); + } // only watch for exception in live channels selecter.WatchException(links[i].sock); finished = false; @@ -439,7 +441,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_, } if (parent_index != -1) { // pass message up to parent, can pass data that are already been reduced - if (selecter.CheckWrite(links[parent_index].sock)) { + if (size_up_out < size_up_reduce) { ssize_t len = links[parent_index].sock. Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out); if (len != -1) { @@ -477,7 +479,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_, } // can pass message down to childs for (int i = 0; i < nlink; ++i) { - if (i != parent_index && selecter.CheckWrite(links[i].sock)) { + if (i != parent_index && links[i].size_write < size_down_in) { ReturnType ret = links[i].WriteFromArray(sendrecvbuf, size_down_in); if (ret != kSuccess) { return ReportError(&links[i], ret); @@ -530,7 +532,10 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { selecter.WatchRead(links[i].sock); finished = false; } if (in_link != -2 && i != in_link && links[i].size_write != total_size) { - selecter.WatchWrite(links[i].sock); finished = false; + if (links[i].size_write < size_in) { + selecter.WatchWrite(links[i].sock); + } + finished = false; } selecter.WatchException(links[i].sock); } @@ -571,11 +576,11 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { } // send data to all out-link for (int i = 0; i < nlink; ++i) { - if (i != in_link && selecter.CheckWrite(links[i].sock)) { + if (i != in_link && links[i].size_write < size_in) { ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, size_in); if (ret != kSuccess) { return ReportError(&links[i], ret); - } + } } } } diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index 38cadbdb2..33fdcd0f0 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -645,8 +645,7 @@ AllreduceRobust::TryRecoverData(RecoverType role, } } for (int i = 0; i < nlink; ++i) { - if (req_in[i] && links[i].size_write != links[pid].size_read && - selecter.CheckWrite(links[i].sock)) { + if (req_in[i] && links[i].size_write != links[pid].size_read) { ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, links[pid].size_read); if (ret != kSuccess) { return ReportError(&links[i], ret); @@ -656,7 +655,7 @@ AllreduceRobust::TryRecoverData(RecoverType role, } if (role == kHaveData) { for (int i = 0; i < nlink; ++i) { - if (req_in[i] && selecter.CheckWrite(links[i].sock)) { + if (req_in[i] && links[i].size_write != size) { ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, size); if (ret != kSuccess) { return ReportError(&links[i], ret); @@ -679,8 +678,7 @@ AllreduceRobust::TryRecoverData(RecoverType role, } } for (int i = 0; i < nlink; ++i) { - if (req_in[i] && selecter.CheckWrite(links[i].sock) && - links[pid].size_read != links[i].size_write) { + if (req_in[i] && links[pid].size_read != links[i].size_write) { size_t start = links[i].size_write % buffer_size; // send out data from ring buffer size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write); @@ -1162,8 +1160,7 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_, if (ret != kSuccess) return ReportError(&prev, ret); } } - if (write_ptr != write_end && write_ptr < read_ptr && - selecter.CheckWrite(next.sock)) { + if (write_ptr != write_end && write_ptr < read_ptr) { size_t nsend = std::min(write_end - write_ptr, read_ptr - write_ptr); ssize_t len = next.sock.Send(buf + write_ptr, nsend); if (len != -1) {