smarter select for allreduce and bcast
This commit is contained in:
parent
f7928c68a3
commit
8cef2086f5
@ -179,12 +179,28 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
|
|||||||
// while we have not passed the messages out
|
// while we have not passed the messages out
|
||||||
while (true) {
|
while (true) {
|
||||||
// select helper
|
// select helper
|
||||||
|
bool finished = true;
|
||||||
utils::SelectHelper selecter;
|
utils::SelectHelper selecter;
|
||||||
for (size_t i = 0; i < links.size(); ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
selecter.WatchRead(links[i].sock);
|
if (i == parent_index) {
|
||||||
selecter.WatchWrite(links[i].sock);
|
if (size_down_in != total_size) {
|
||||||
|
selecter.WatchRead(links[i].sock); finished = false;
|
||||||
|
}
|
||||||
|
if (size_up_out != total_size) {
|
||||||
|
selecter.WatchWrite(links[i].sock);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (links[i].size_read != total_size) {
|
||||||
|
selecter.WatchRead(links[i].sock);
|
||||||
|
}
|
||||||
|
if (links[i].size_write != total_size) {
|
||||||
|
selecter.WatchWrite(links[i].sock); finished = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
selecter.WatchException(links[i].sock);
|
selecter.WatchException(links[i].sock);
|
||||||
}
|
}
|
||||||
|
// finish runing allreduce
|
||||||
|
if (finished) break;
|
||||||
// select must return
|
// select must return
|
||||||
selecter.Select();
|
selecter.Select();
|
||||||
// exception handling
|
// exception handling
|
||||||
@ -261,19 +277,12 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
|
|||||||
// this is root, can use reduce as most recent point
|
// this is root, can use reduce as most recent point
|
||||||
size_down_in = size_up_out = size_up_reduce;
|
size_down_in = size_up_out = size_up_reduce;
|
||||||
}
|
}
|
||||||
// check if we finished the job of message passing
|
|
||||||
size_t nfinished = size_down_in;
|
|
||||||
// can pass message down to childs
|
// can pass message down to childs
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (i != parent_index) {
|
if (i != parent_index && selecter.CheckWrite(links[i].sock)) {
|
||||||
if (selecter.CheckWrite(links[i].sock)) {
|
if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) return kSockError;
|
||||||
if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) return kSockError;
|
|
||||||
}
|
|
||||||
nfinished = std::min(links[i].size_write, nfinished);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// check boundary condition
|
|
||||||
if (nfinished >= total_size) break;
|
|
||||||
}
|
}
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
@ -288,6 +297,7 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
|
|||||||
AllReduceBase::ReturnType
|
AllReduceBase::ReturnType
|
||||||
AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||||
if (links.size() == 0 || total_size == 0) return kSuccess;
|
if (links.size() == 0 || total_size == 0) return kSuccess;
|
||||||
|
utils::Check(root < world_size, "Broadcast: root should be smaller than world size");
|
||||||
// number of links
|
// number of links
|
||||||
const int nlink = static_cast<int>(links.size());
|
const int nlink = static_cast<int>(links.size());
|
||||||
// size of space already read from data
|
// size of space already read from data
|
||||||
@ -306,13 +316,25 @@ AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
|||||||
}
|
}
|
||||||
// while we have not passed the messages out
|
// while we have not passed the messages out
|
||||||
while(true) {
|
while(true) {
|
||||||
|
bool finished = true;
|
||||||
// select helper
|
// select helper
|
||||||
utils::SelectHelper selecter;
|
utils::SelectHelper selecter;
|
||||||
for (size_t i = 0; i < links.size(); ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
selecter.WatchRead(links[i].sock);
|
if (in_link == -2) {
|
||||||
selecter.WatchWrite(links[i].sock);
|
selecter.WatchRead(links[i].sock); finished = false;
|
||||||
|
}
|
||||||
|
if (i == in_link && links[i].size_read != total_size) {
|
||||||
|
selecter.WatchRead(links[i].sock); finished = false;
|
||||||
|
}
|
||||||
|
if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
|
||||||
|
selecter.WatchWrite(links[i].sock); finished = false;
|
||||||
|
}
|
||||||
selecter.WatchException(links[i].sock);
|
selecter.WatchException(links[i].sock);
|
||||||
}
|
}
|
||||||
|
// finish running
|
||||||
|
if (finished) break;
|
||||||
|
// select
|
||||||
|
selecter.Select();
|
||||||
// exception handling
|
// exception handling
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
// recive OOB message from some link
|
// recive OOB message from some link
|
||||||
@ -336,18 +358,12 @@ AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
|||||||
size_in = links[in_link].size_read;
|
size_in = links[in_link].size_read;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
size_t nfinished = total_size;
|
|
||||||
// send data to all out-link
|
// send data to all out-link
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (i != in_link) {
|
if (i != in_link && selecter.CheckWrite(links[i].sock)) {
|
||||||
if (selecter.CheckWrite(links[i].sock)) {
|
if (!links[i].WriteFromArray(sendrecvbuf_, size_in)) return kSockError;
|
||||||
if (!links[i].WriteFromArray(sendrecvbuf_, size_in)) return kSockError;
|
|
||||||
}
|
|
||||||
nfinished = std::min(nfinished, links[i].size_write);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// check boundary condition
|
|
||||||
if (nfinished >= total_size) break;
|
|
||||||
}
|
}
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -366,10 +366,10 @@ AllReduceRobust::TryRecoverData(RecoverType role,
|
|||||||
// do not need to provide data or receive data, directly exit
|
// do not need to provide data or receive data, directly exit
|
||||||
if (!req_data) return kSuccess;
|
if (!req_data) return kSuccess;
|
||||||
}
|
}
|
||||||
|
utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
links[i].ResetSize();
|
links[i].ResetSize();
|
||||||
}
|
}
|
||||||
utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
|
|
||||||
if (role == kPassData) {
|
if (role == kPassData) {
|
||||||
links[recv_link].InitBuffer(1, size, reduce_buffer_size);
|
links[recv_link].InitBuffer(1, size, reduce_buffer_size);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -70,6 +70,7 @@ int main(int argc, char *argv[]) {
|
|||||||
int n = atoi(argv[1]);
|
int n = atoi(argv[1]);
|
||||||
sync::Init(argc, argv);
|
sync::Init(argc, argv);
|
||||||
int rank = sync::GetRank();
|
int rank = sync::GetRank();
|
||||||
|
int nproc = sync::GetWorldSize();
|
||||||
std::string name = sync::GetProcessorName();
|
std::string name = sync::GetProcessorName();
|
||||||
|
|
||||||
test::Mock mock(rank, argv[2], argv[3]);
|
test::Mock mock(rank, argv[2], argv[3]);
|
||||||
@ -79,6 +80,10 @@ int main(int argc, char *argv[]) {
|
|||||||
utils::LogPrintf("[%d] !!!TestMax pass\n", rank);
|
utils::LogPrintf("[%d] !!!TestMax pass\n", rank);
|
||||||
TestSum(mock, n);
|
TestSum(mock, n);
|
||||||
utils::LogPrintf("[%d] !!!TestSum pass\n", rank);
|
utils::LogPrintf("[%d] !!!TestSum pass\n", rank);
|
||||||
|
for (int i = 0; i < nproc; i += nproc / 3) {
|
||||||
|
TestBcast(mock, n, i);
|
||||||
|
}
|
||||||
|
utils::LogPrintf("[%d] !!!TestBcast pass\n", rank);
|
||||||
sync::Finalize();
|
sync::Finalize();
|
||||||
printf("[%d] all check pass\n", rank);
|
printf("[%d] all check pass\n", rank);
|
||||||
return 0;
|
return 0;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user