diff --git a/src/allreduce_base.h b/src/allreduce_base.h index 4194beb13..ef5567af1 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -90,7 +90,7 @@ class AllreduceBase : public IEngine { PreprocFunction prepare_fun = NULL, void *prepare_arg = NULL) { if (prepare_fun != NULL) prepare_fun(prepare_arg); - if (world_size == 1) return; + if (world_size == 1 || world_size == -1) return; utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) == kSuccess, "Allreduce failed"); @@ -102,7 +102,7 @@ class AllreduceBase : public IEngine { * \param root the root worker id to broadcast the data */ virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) { - if (world_size == 1) return; + if (world_size == 1 || world_size == -1) return; utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess, "Broadcast failed"); } diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index c89b69542..fc47982f7 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -114,7 +114,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_, */ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) { // skip action in single node - if (world_size == 1) return; + if (world_size == 1 || world_size == -1) return; bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter); // now we are free to remove the last result, if any if (resbuf.LastSeqNo() != -1 &&