smarter select for allreduce and bcast

This commit is contained in:
tqchen 2014-11-30 21:31:45 -08:00
parent f7928c68a3
commit 8cef2086f5
3 changed files with 45 additions and 24 deletions

View File

@ -179,12 +179,28 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
// while we have not passed the messages out // while we have not passed the messages out
while (true) { while (true) {
// select helper // select helper
bool finished = true;
utils::SelectHelper selecter; utils::SelectHelper selecter;
for (size_t i = 0; i < links.size(); ++i) { for (int i = 0; i < nlink; ++i) {
selecter.WatchRead(links[i].sock); if (i == parent_index) {
selecter.WatchWrite(links[i].sock); if (size_down_in != total_size) {
selecter.WatchRead(links[i].sock); finished = false;
}
if (size_up_out != total_size) {
selecter.WatchWrite(links[i].sock);
}
} else {
if (links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock);
}
if (links[i].size_write != total_size) {
selecter.WatchWrite(links[i].sock); finished = false;
}
}
selecter.WatchException(links[i].sock); selecter.WatchException(links[i].sock);
} }
// finish running allreduce
if (finished) break;
// select must return // select must return
selecter.Select(); selecter.Select();
// exception handling // exception handling
@ -261,19 +277,12 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
// this is root, can use reduce as most recent point // this is root, can use reduce as most recent point
size_down_in = size_up_out = size_up_reduce; size_down_in = size_up_out = size_up_reduce;
} }
// check if we finished the job of message passing
size_t nfinished = size_down_in;
// can pass message down to children // can pass message down to children
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (i != parent_index) { if (i != parent_index && selecter.CheckWrite(links[i].sock)) {
if (selecter.CheckWrite(links[i].sock)) { if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) return kSockError;
if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) return kSockError;
}
nfinished = std::min(links[i].size_write, nfinished);
} }
} }
// check boundary condition
if (nfinished >= total_size) break;
} }
return kSuccess; return kSuccess;
} }
@ -288,6 +297,7 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
AllReduceBase::ReturnType AllReduceBase::ReturnType
AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
if (links.size() == 0 || total_size == 0) return kSuccess; if (links.size() == 0 || total_size == 0) return kSuccess;
utils::Check(root < world_size, "Broadcast: root should be smaller than world size");
// number of links // number of links
const int nlink = static_cast<int>(links.size()); const int nlink = static_cast<int>(links.size());
// size of space already read from data // size of space already read from data
@ -306,13 +316,25 @@ AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
} }
// while we have not passed the messages out // while we have not passed the messages out
while(true) { while(true) {
bool finished = true;
// select helper // select helper
utils::SelectHelper selecter; utils::SelectHelper selecter;
for (size_t i = 0; i < links.size(); ++i) { for (int i = 0; i < nlink; ++i) {
selecter.WatchRead(links[i].sock); if (in_link == -2) {
selecter.WatchWrite(links[i].sock); selecter.WatchRead(links[i].sock); finished = false;
}
if (i == in_link && links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock); finished = false;
}
if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
selecter.WatchWrite(links[i].sock); finished = false;
}
selecter.WatchException(links[i].sock); selecter.WatchException(links[i].sock);
} }
// finish running
if (finished) break;
// select
selecter.Select();
// exception handling // exception handling
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link // receive OOB message from some link
@ -336,18 +358,12 @@ AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
size_in = links[in_link].size_read; size_in = links[in_link].size_read;
} }
} }
size_t nfinished = total_size;
// send data to all out-link // send data to all out-link
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (i != in_link) { if (i != in_link && selecter.CheckWrite(links[i].sock)) {
if (selecter.CheckWrite(links[i].sock)) { if (!links[i].WriteFromArray(sendrecvbuf_, size_in)) return kSockError;
if (!links[i].WriteFromArray(sendrecvbuf_, size_in)) return kSockError;
}
nfinished = std::min(nfinished, links[i].size_write);
} }
} }
// check boundary condition
if (nfinished >= total_size) break;
} }
return kSuccess; return kSuccess;
} }

View File

@ -366,10 +366,10 @@ AllReduceRobust::TryRecoverData(RecoverType role,
// do not need to provide data or receive data, directly exit // do not need to provide data or receive data, directly exit
if (!req_data) return kSuccess; if (!req_data) return kSuccess;
} }
utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
links[i].ResetSize(); links[i].ResetSize();
} }
utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
if (role == kPassData) { if (role == kPassData) {
links[recv_link].InitBuffer(1, size, reduce_buffer_size); links[recv_link].InitBuffer(1, size, reduce_buffer_size);
} }

View File

@ -70,6 +70,7 @@ int main(int argc, char *argv[]) {
int n = atoi(argv[1]); int n = atoi(argv[1]);
sync::Init(argc, argv); sync::Init(argc, argv);
int rank = sync::GetRank(); int rank = sync::GetRank();
int nproc = sync::GetWorldSize();
std::string name = sync::GetProcessorName(); std::string name = sync::GetProcessorName();
test::Mock mock(rank, argv[2], argv[3]); test::Mock mock(rank, argv[2], argv[3]);
@ -79,6 +80,10 @@ int main(int argc, char *argv[]) {
utils::LogPrintf("[%d] !!!TestMax pass\n", rank); utils::LogPrintf("[%d] !!!TestMax pass\n", rank);
TestSum(mock, n); TestSum(mock, n);
utils::LogPrintf("[%d] !!!TestSum pass\n", rank); utils::LogPrintf("[%d] !!!TestSum pass\n", rank);
for (int i = 0; i < nproc; i += nproc / 3) {
TestBcast(mock, n, i);
}
utils::LogPrintf("[%d] !!!TestBcast pass\n", rank);
sync::Finalize(); sync::Finalize();
printf("[%d] all check pass\n", rank); printf("[%d] all check pass\n", rank);
return 0; return 0;