allreduce_robust.cc: Allow num_global_replica to be 0 (#38)

In some cases, users may not want to have any global replica of
the data being broadcasted/all-reduced. In such cases, set the
result_buffer_round to -1 as a flag that this is not necessary
and check for it.
This commit is contained in:
AbdealiJK 2016-11-24 09:04:11 +05:30 committed by Tianqi Chen
parent 032152ad24
commit 21b5e12913

View File

@ -33,7 +33,11 @@ AllreduceRobust::AllreduceRobust(void) {
}
void AllreduceRobust::Init(int argc, char* argv[]) {
AllreduceBase::Init(argc, argv);
result_buffer_round = std::max(world_size / num_global_replica, 1);
if (num_global_replica == 0) {
result_buffer_round = -1;
} else {
result_buffer_round = std::max(world_size / num_global_replica, 1);
}
}
/*! \brief shutdown the engine */
void AllreduceRobust::Shutdown(void) {
@ -86,7 +90,8 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_,
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
// now we are free to remove the last result, if any
if (resbuf.LastSeqNo() != -1 &&
(resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
(result_buffer_round == -1 ||
resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
resbuf.DropLast();
}
if (!recovered && prepare_fun != NULL) prepare_fun(prepare_arg);
@ -118,7 +123,8 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
// now we are free to remove the last result, if any
if (resbuf.LastSeqNo() != -1 &&
(resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
(result_buffer_round == -1 ||
resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
resbuf.DropLast();
}
void *temp = resbuf.AllocTemp(1, total_size);