allreduce_robust.cc: Allow num_global_replica to be 0 (#38)
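In some cases, users may not want to keep any global replica of the data being broadcast or all-reduced. In that case (num_global_replica == 0), set result_buffer_round to -1 as a flag meaning that no replica is needed, and check for this flag in Allreduce and Broadcast before evaluating the modulo-based test, so that the last buffered result is always dropped.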
This commit is contained in:
parent
032152ad24
commit
21b5e12913
@ -33,8 +33,12 @@ AllreduceRobust::AllreduceRobust(void) {
|
|||||||
}
|
}
|
||||||
void AllreduceRobust::Init(int argc, char* argv[]) {
|
void AllreduceRobust::Init(int argc, char* argv[]) {
|
||||||
AllreduceBase::Init(argc, argv);
|
AllreduceBase::Init(argc, argv);
|
||||||
|
if (num_global_replica == 0) {
|
||||||
|
result_buffer_round = -1;
|
||||||
|
} else {
|
||||||
result_buffer_round = std::max(world_size / num_global_replica, 1);
|
result_buffer_round = std::max(world_size / num_global_replica, 1);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
/*! \brief shutdown the engine */
|
/*! \brief shutdown the engine */
|
||||||
void AllreduceRobust::Shutdown(void) {
|
void AllreduceRobust::Shutdown(void) {
|
||||||
// need to sync the exec before we shutdown, do a pseudo checkpoint
|
// need to sync the exec before we shutdown, do a pseudo checkpoint
|
||||||
@ -86,7 +90,8 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_,
|
|||||||
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
|
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
|
||||||
// now we are free to remove the last result, if any
|
// now we are free to remove the last result, if any
|
||||||
if (resbuf.LastSeqNo() != -1 &&
|
if (resbuf.LastSeqNo() != -1 &&
|
||||||
(resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
|
(result_buffer_round == -1 ||
|
||||||
|
resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
|
||||||
resbuf.DropLast();
|
resbuf.DropLast();
|
||||||
}
|
}
|
||||||
if (!recovered && prepare_fun != NULL) prepare_fun(prepare_arg);
|
if (!recovered && prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||||
@ -118,7 +123,8 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
|
|||||||
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
|
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
|
||||||
// now we are free to remove the last result, if any
|
// now we are free to remove the last result, if any
|
||||||
if (resbuf.LastSeqNo() != -1 &&
|
if (resbuf.LastSeqNo() != -1 &&
|
||||||
(resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
|
(result_buffer_round == -1 ||
|
||||||
|
resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
|
||||||
resbuf.DropLast();
|
resbuf.DropLast();
|
||||||
}
|
}
|
||||||
void *temp = resbuf.AllocTemp(1, total_size);
|
void *temp = resbuf.AllocTemp(1, total_size);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user