smarter select for allreduce and bcast
This commit is contained in:
parent
f7928c68a3
commit
8cef2086f5
@ -179,12 +179,28 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
|
||||
// while we have not passed the messages out
|
||||
while (true) {
|
||||
// select helper
|
||||
bool finished = true;
|
||||
utils::SelectHelper selecter;
|
||||
for (size_t i = 0; i < links.size(); ++i) {
|
||||
selecter.WatchRead(links[i].sock);
|
||||
selecter.WatchWrite(links[i].sock);
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i == parent_index) {
|
||||
if (size_down_in != total_size) {
|
||||
selecter.WatchRead(links[i].sock); finished = false;
|
||||
}
|
||||
if (size_up_out != total_size) {
|
||||
selecter.WatchWrite(links[i].sock);
|
||||
}
|
||||
} else {
|
||||
if (links[i].size_read != total_size) {
|
||||
selecter.WatchRead(links[i].sock);
|
||||
}
|
||||
if (links[i].size_write != total_size) {
|
||||
selecter.WatchWrite(links[i].sock); finished = false;
|
||||
}
|
||||
}
|
||||
selecter.WatchException(links[i].sock);
|
||||
}
|
||||
// finish running allreduce
|
||||
if (finished) break;
|
||||
// select must return
|
||||
selecter.Select();
|
||||
// exception handling
|
||||
@ -261,19 +277,12 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
|
||||
// this is root, can use reduce as most recent point
|
||||
size_down_in = size_up_out = size_up_reduce;
|
||||
}
|
||||
// check if we finished the job of message passing
|
||||
size_t nfinished = size_down_in;
|
||||
// can pass message down to children
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index) {
|
||||
if (selecter.CheckWrite(links[i].sock)) {
|
||||
if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) return kSockError;
|
||||
}
|
||||
nfinished = std::min(links[i].size_write, nfinished);
|
||||
if (i != parent_index && selecter.CheckWrite(links[i].sock)) {
|
||||
if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) return kSockError;
|
||||
}
|
||||
}
|
||||
// check boundary condition
|
||||
if (nfinished >= total_size) break;
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
@ -288,6 +297,7 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
|
||||
AllReduceBase::ReturnType
|
||||
AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
if (links.size() == 0 || total_size == 0) return kSuccess;
|
||||
utils::Check(root < world_size, "Broadcast: root should be smaller than world size");
|
||||
// number of links
|
||||
const int nlink = static_cast<int>(links.size());
|
||||
// size of space already read from data
|
||||
@ -306,13 +316,25 @@ AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
}
|
||||
// while we have not passed the messages out
|
||||
while(true) {
|
||||
bool finished = true;
|
||||
// select helper
|
||||
utils::SelectHelper selecter;
|
||||
for (size_t i = 0; i < links.size(); ++i) {
|
||||
selecter.WatchRead(links[i].sock);
|
||||
selecter.WatchWrite(links[i].sock);
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (in_link == -2) {
|
||||
selecter.WatchRead(links[i].sock); finished = false;
|
||||
}
|
||||
if (i == in_link && links[i].size_read != total_size) {
|
||||
selecter.WatchRead(links[i].sock); finished = false;
|
||||
}
|
||||
if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
|
||||
selecter.WatchWrite(links[i].sock); finished = false;
|
||||
}
|
||||
selecter.WatchException(links[i].sock);
|
||||
}
|
||||
// finish running
|
||||
if (finished) break;
|
||||
// select
|
||||
selecter.Select();
|
||||
// exception handling
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
// receive OOB message from some link
|
||||
@ -336,18 +358,12 @@ AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
size_in = links[in_link].size_read;
|
||||
}
|
||||
}
|
||||
size_t nfinished = total_size;
|
||||
// send data to all out-link
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != in_link) {
|
||||
if (selecter.CheckWrite(links[i].sock)) {
|
||||
if (!links[i].WriteFromArray(sendrecvbuf_, size_in)) return kSockError;
|
||||
}
|
||||
nfinished = std::min(nfinished, links[i].size_write);
|
||||
if (i != in_link && selecter.CheckWrite(links[i].sock)) {
|
||||
if (!links[i].WriteFromArray(sendrecvbuf_, size_in)) return kSockError;
|
||||
}
|
||||
}
|
||||
// check boundary condition
|
||||
if (nfinished >= total_size) break;
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
@ -366,10 +366,10 @@ AllReduceRobust::TryRecoverData(RecoverType role,
|
||||
// do not need to provide data or receive data, directly exit
|
||||
if (!req_data) return kSuccess;
|
||||
}
|
||||
utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
links[i].ResetSize();
|
||||
}
|
||||
utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
|
||||
if (role == kPassData) {
|
||||
links[recv_link].InitBuffer(1, size, reduce_buffer_size);
|
||||
}
|
||||
|
||||
@ -70,6 +70,7 @@ int main(int argc, char *argv[]) {
|
||||
int n = atoi(argv[1]);
|
||||
sync::Init(argc, argv);
|
||||
int rank = sync::GetRank();
|
||||
int nproc = sync::GetWorldSize();
|
||||
std::string name = sync::GetProcessorName();
|
||||
|
||||
test::Mock mock(rank, argv[2], argv[3]);
|
||||
@ -79,6 +80,10 @@ int main(int argc, char *argv[]) {
|
||||
utils::LogPrintf("[%d] !!!TestMax pass\n", rank);
|
||||
TestSum(mock, n);
|
||||
utils::LogPrintf("[%d] !!!TestSum pass\n", rank);
|
||||
for (int i = 0; i < nproc; i += nproc / 3) {
|
||||
TestBcast(mock, n, i);
|
||||
}
|
||||
utils::LogPrintf("[%d] !!!TestBcast pass\n", rank);
|
||||
sync::Finalize();
|
||||
printf("[%d] all check pass\n", rank);
|
||||
return 0;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user