diff --git a/src/engine_base.cc b/src/engine_base.cc index 9f0aaa405..00ac1cffb 100644 --- a/src/engine_base.cc +++ b/src/engine_base.cc @@ -179,12 +179,28 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_, // while we have not passed the messages out while (true) { // select helper + bool finished = true; utils::SelectHelper selecter; - for (size_t i = 0; i < links.size(); ++i) { - selecter.WatchRead(links[i].sock); - selecter.WatchWrite(links[i].sock); + for (int i = 0; i < nlink; ++i) { + if (i == parent_index) { + if (size_down_in != total_size) { + selecter.WatchRead(links[i].sock); finished = false; + } + if (size_up_out != total_size) { + selecter.WatchWrite(links[i].sock); + } + } else { + if (links[i].size_read != total_size) { + selecter.WatchRead(links[i].sock); + } + if (links[i].size_write != total_size) { + selecter.WatchWrite(links[i].sock); finished = false; + } + } selecter.WatchException(links[i].sock); } + // finish runing allreduce + if (finished) break; // select must return selecter.Select(); // exception handling @@ -261,19 +277,12 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_, // this is root, can use reduce as most recent point size_down_in = size_up_out = size_up_reduce; } - // check if we finished the job of message passing - size_t nfinished = size_down_in; // can pass message down to childs for (int i = 0; i < nlink; ++i) { - if (i != parent_index) { - if (selecter.CheckWrite(links[i].sock)) { - if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) return kSockError; - } - nfinished = std::min(links[i].size_write, nfinished); + if (i != parent_index && selecter.CheckWrite(links[i].sock)) { + if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) return kSockError; } } - // check boundary condition - if (nfinished >= total_size) break; } return kSuccess; } @@ -288,6 +297,7 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_, AllReduceBase::ReturnType AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { if (links.size() == 0 || total_size == 0) return kSuccess; + utils::Check(root < world_size, "Broadcast: root should be smaller than world size"); // number of links const int nlink = static_cast(links.size()); // size of space already read from data @@ -306,13 +316,25 @@ AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { } // while we have not passed the messages out while(true) { + bool finished = true; // select helper utils::SelectHelper selecter; - for (size_t i = 0; i < links.size(); ++i) { - selecter.WatchRead(links[i].sock); - selecter.WatchWrite(links[i].sock); + for (int i = 0; i < nlink; ++i) { + if (in_link == -2) { + selecter.WatchRead(links[i].sock); finished = false; + } + if (i == in_link && links[i].size_read != total_size) { + selecter.WatchRead(links[i].sock); finished = false; + } + if (in_link != -2 && i != in_link && links[i].size_write != total_size) { + selecter.WatchWrite(links[i].sock); finished = false; + } selecter.WatchException(links[i].sock); } + // finish running + if (finished) break; + // select + selecter.Select(); // exception handling for (int i = 0; i < nlink; ++i) { // recive OOB message from some link @@ -336,18 +358,12 @@ AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { size_in = links[in_link].size_read; } } - size_t nfinished = total_size; // send data to all out-link for (int i = 0; i < nlink; ++i) { - if (i != in_link) { - if (selecter.CheckWrite(links[i].sock)) { - if (!links[i].WriteFromArray(sendrecvbuf_, size_in)) return kSockError; - } - nfinished = std::min(nfinished, links[i].size_write); + if (i != in_link && selecter.CheckWrite(links[i].sock)) { + if (!links[i].WriteFromArray(sendrecvbuf_, size_in)) return kSockError; } } - // check boundary condition - if (nfinished >= total_size) break; } return kSuccess; } diff --git a/src/engine_robust.cc b/src/engine_robust.cc index 7f510d2f3..9f03bea5e 100644 --- a/src/engine_robust.cc +++ b/src/engine_robust.cc @@ -366,10 +366,10 @@ AllReduceRobust::TryRecoverData(RecoverType role, // do not need to provide data or receive data, directly exit if (!req_data) return kSuccess; } + utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active"); for (int i = 0; i < nlink; ++i) { links[i].ResetSize(); } - utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active"); if (role == kPassData) { links[recv_link].InitBuffer(1, size, reduce_buffer_size); } diff --git a/test/test_allreduce.cpp b/test/test_allreduce.cpp index 3a2cc2a9d..02cb4057f 100644 --- a/test/test_allreduce.cpp +++ b/test/test_allreduce.cpp @@ -70,6 +70,7 @@ int main(int argc, char *argv[]) { int n = atoi(argv[1]); sync::Init(argc, argv); int rank = sync::GetRank(); + int nproc = sync::GetWorldSize(); std::string name = sync::GetProcessorName(); test::Mock mock(rank, argv[2], argv[3]); @@ -79,6 +80,10 @@ int main(int argc, char *argv[]) { utils::LogPrintf("[%d] !!!TestMax pass\n", rank); TestSum(mock, n); utils::LogPrintf("[%d] !!!TestSum pass\n", rank); + for (int i = 0; i < nproc; i += nproc / 3) { + TestBcast(mock, n, i); + } + utils::LogPrintf("[%d] !!!TestBcast pass\n", rank); sync::Finalize(); printf("[%d] all check pass\n", rank); return 0;