smarter select for allreduce and bcast

This commit is contained in:
tqchen 2014-11-30 21:31:45 -08:00
parent f7928c68a3
commit 8cef2086f5
3 changed files with 45 additions and 24 deletions

View File

@ -179,12 +179,28 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
// while we have not passed the messages out
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
for (size_t i = 0; i < links.size(); ++i) {
selecter.WatchRead(links[i].sock);
selecter.WatchWrite(links[i].sock);
for (int i = 0; i < nlink; ++i) {
if (i == parent_index) {
if (size_down_in != total_size) {
selecter.WatchRead(links[i].sock); finished = false;
}
if (size_up_out != total_size) {
selecter.WatchWrite(links[i].sock);
}
} else {
if (links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock);
}
if (links[i].size_write != total_size) {
selecter.WatchWrite(links[i].sock); finished = false;
}
}
selecter.WatchException(links[i].sock);
}
// finish runing allreduce
if (finished) break;
// select must return
selecter.Select();
// exception handling
@ -261,19 +277,12 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
// this is root, can use reduce as most recent point
size_down_in = size_up_out = size_up_reduce;
}
// check if we finished the job of message passing
size_t nfinished = size_down_in;
// can pass message down to childs
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
if (selecter.CheckWrite(links[i].sock)) {
if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) return kSockError;
}
nfinished = std::min(links[i].size_write, nfinished);
if (i != parent_index && selecter.CheckWrite(links[i].sock)) {
if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) return kSockError;
}
}
// check boundary condition
if (nfinished >= total_size) break;
}
return kSuccess;
}
@ -288,6 +297,7 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
AllReduceBase::ReturnType
AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
if (links.size() == 0 || total_size == 0) return kSuccess;
utils::Check(root < world_size, "Broadcast: root should be smaller than world size");
// number of links
const int nlink = static_cast<int>(links.size());
// size of space already read from data
@ -306,13 +316,25 @@ AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
}
// while we have not passed the messages out
while(true) {
bool finished = true;
// select helper
utils::SelectHelper selecter;
for (size_t i = 0; i < links.size(); ++i) {
selecter.WatchRead(links[i].sock);
selecter.WatchWrite(links[i].sock);
for (int i = 0; i < nlink; ++i) {
if (in_link == -2) {
selecter.WatchRead(links[i].sock); finished = false;
}
if (i == in_link && links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock); finished = false;
}
if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
selecter.WatchWrite(links[i].sock); finished = false;
}
selecter.WatchException(links[i].sock);
}
// finish running
if (finished) break;
// select
selecter.Select();
// exception handling
for (int i = 0; i < nlink; ++i) {
// recive OOB message from some link
@ -336,18 +358,12 @@ AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
size_in = links[in_link].size_read;
}
}
size_t nfinished = total_size;
// send data to all out-link
for (int i = 0; i < nlink; ++i) {
if (i != in_link) {
if (selecter.CheckWrite(links[i].sock)) {
if (!links[i].WriteFromArray(sendrecvbuf_, size_in)) return kSockError;
}
nfinished = std::min(nfinished, links[i].size_write);
if (i != in_link && selecter.CheckWrite(links[i].sock)) {
if (!links[i].WriteFromArray(sendrecvbuf_, size_in)) return kSockError;
}
}
// check boundary condition
if (nfinished >= total_size) break;
}
return kSuccess;
}

View File

@ -366,10 +366,10 @@ AllReduceRobust::TryRecoverData(RecoverType role,
// do not need to provide data or receive data, directly exit
if (!req_data) return kSuccess;
}
utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
for (int i = 0; i < nlink; ++i) {
links[i].ResetSize();
}
utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
if (role == kPassData) {
links[recv_link].InitBuffer(1, size, reduce_buffer_size);
}

View File

@ -70,6 +70,7 @@ int main(int argc, char *argv[]) {
int n = atoi(argv[1]);
sync::Init(argc, argv);
int rank = sync::GetRank();
int nproc = sync::GetWorldSize();
std::string name = sync::GetProcessorName();
test::Mock mock(rank, argv[2], argv[3]);
@ -79,6 +80,10 @@ int main(int argc, char *argv[]) {
utils::LogPrintf("[%d] !!!TestMax pass\n", rank);
TestSum(mock, n);
utils::LogPrintf("[%d] !!!TestSum pass\n", rank);
for (int i = 0; i < nproc; i += nproc / 3) {
TestBcast(mock, n, i);
}
utils::LogPrintf("[%d] !!!TestBcast pass\n", rank);
sync::Finalize();
printf("[%d] all check pass\n", rank);
return 0;