diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index f2f75c19e..88a6dcace 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -249,7 +249,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { for (int i = 0; i < nlink; ++i) { all_links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size); all_links[i].ResetSize(); - } + } // read and discard data from all channels until pass mark while (true) { for (int i = 0; i < nlink; ++i) { @@ -283,7 +283,17 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { } } while (true) { + utils::SelectHelper rsel; + bool finished = true; for (int i = 0; i < nlink; ++i) { + if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) { + rsel.WatchRead(all_links[i].sock); finished = false; + } + } + if (finished) break; + rsel.Select(); + for (int i = 0; i < nlink; ++i) { + if (all_links[i].sock.BadSocket()) continue; if (all_links[i].size_read == 0) { int atmark = all_links[i].sock.AtMark(); if (atmark < 0) { @@ -299,17 +309,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { } } } - utils::SelectHelper rsel; - bool finished = true; - for (int i = 0; i < nlink; ++i) { - if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) { - rsel.WatchRead(all_links[i].sock); finished = false; - } - } - if (finished) break; - rsel.Select(); } - // start synchronization, use blocking I/O to avoid select for (int i = 0; i < nlink; ++i) { if (!all_links[i].sock.BadSocket()) { @@ -365,13 +365,15 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { */ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) { if (err_type == kSuccess) return true; - // simple way, shutdown all links - for (size_t i = 0; i < all_links.size(); ++i) { - if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close(); + {// simple way, shutdown all links + for (size_t i = 0; i < all_links.size(); ++i) { + if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close(); + } + ReConnectLinks("recover"); + return false; } - ReConnectLinks("recover"); - return false; // this was old way + // TryResetLinks still causes possible errors, so not use this one while(err_type != kSuccess) { switch(err_type) { case kGetExcept: err_type = TryResetLinks(); break; diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h index e43e9ac66..92c682b12 100644 --- a/src/allreduce_robust.h +++ b/src/allreduce_robust.h @@ -95,7 +95,11 @@ class AllreduceRobust : public AllreduceBase { * this function is only used for test purpose */ virtual void InitAfterException(void) { - this->CheckAndRecover(kGetExcept); + // simple way, shutdown all links + for (size_t i = 0; i < all_links.size(); ++i) { + if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close(); + } + ReConnectLinks("recover"); } private: