diff --git a/src/engine_base.cc b/src/engine_base.cc index 00ac1cffb..eec2330fc 100644 --- a/src/engine_base.cc +++ b/src/engine_base.cc @@ -174,8 +174,7 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_, // if no childs, no need to reduce if (nlink == static_cast(parent_index != -1)) { size_up_reduce = total_size; - } - + } // while we have not passed the messages out while (true) { // select helper @@ -184,7 +183,10 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_, for (int i = 0; i < nlink; ++i) { if (i == parent_index) { if (size_down_in != total_size) { - selecter.WatchRead(links[i].sock); finished = false; + selecter.WatchRead(links[i].sock); + // only watch for exception in live channels + selecter.WatchException(links[i].sock); + finished = false; } if (size_up_out != total_size) { selecter.WatchWrite(links[i].sock); @@ -193,11 +195,15 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_, if (links[i].size_read != total_size) { selecter.WatchRead(links[i].sock); } + // size_write <= size_read if (links[i].size_write != total_size) { - selecter.WatchWrite(links[i].sock); finished = false; + selecter.WatchWrite(links[i].sock); + // only watch for exception in live channels + selecter.WatchException(links[i].sock); + finished = false; } } - selecter.WatchException(links[i].sock); + } // finish runing allreduce if (finished) break;