Patch summary (free text before the first "diff --git" header; not part of any hunk):

  AllreduceRobust::Init no longer evaluates world_size / num_global_replica
  when num_global_replica is 0. In that case result_buffer_round is set to
  the sentinel -1 instead.

  AllreduceRobust::Allreduce and AllreduceRobust::Broadcast test for that
  sentinel first: when result_buffer_round == -1 the last buffered result is
  dropped whenever resbuf.LastSeqNo() != -1, and the modulo by
  result_buffer_round is not evaluated.

  NOTE(review): this patch was recovered from a copy whose line breaks had
  been collapsed. Line boundaries were rebuilt from the hunk headers
  (old/new line counts match). Leading indentation of the C++ lines is
  assumed to be the file's 2-space style -- confirm against the target file,
  or apply with whitespace-insensitive matching if context fails to match.

diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc
index fc47982f7..a48a349a1 100644
--- a/src/allreduce_robust.cc
+++ b/src/allreduce_robust.cc
@@ -33,7 +33,11 @@ AllreduceRobust::AllreduceRobust(void) {
 }
 void AllreduceRobust::Init(int argc, char* argv[]) {
   AllreduceBase::Init(argc, argv);
-  result_buffer_round = std::max(world_size / num_global_replica, 1);
+  if (num_global_replica == 0) {
+    result_buffer_round = -1;
+  } else {
+    result_buffer_round = std::max(world_size / num_global_replica, 1);
+  }
 }
 /*! \brief shutdown the engine */
 void AllreduceRobust::Shutdown(void) {
@@ -86,7 +90,8 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_,
   bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
   // now we are free to remove the last result, if any
   if (resbuf.LastSeqNo() != -1 &&
-      (resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
+      (result_buffer_round == -1 ||
+       resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
     resbuf.DropLast();
   }
   if (!recovered && prepare_fun != NULL) prepare_fun(prepare_arg);
@@ -118,7 +123,8 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
   bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
   // now we are free to remove the last result, if any
   if (resbuf.LastSeqNo() != -1 &&
-      (resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
+      (result_buffer_round == -1 ||
+       resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
     resbuf.DropLast();
   }
   void *temp = resbuf.AllocTemp(1, total_size);