Next round: try a more careful select design
This commit is contained in:
parent
ecb09a23bc
commit
f7928c68a3
@ -151,7 +151,7 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
|
|||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer) {
|
ReduceFunction reducer) {
|
||||||
if (links.size() == 0) return kSuccess;
|
if (links.size() == 0 || count == 0) return kSuccess;
|
||||||
// total size of message
|
// total size of message
|
||||||
const size_t total_size = type_nbytes * count;
|
const size_t total_size = type_nbytes * count;
|
||||||
// number of links
|
// number of links
|
||||||
@ -287,7 +287,7 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
|
|||||||
*/
|
*/
|
||||||
AllReduceBase::ReturnType
|
AllReduceBase::ReturnType
|
||||||
AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||||
if (links.size() == 0) return kSuccess;
|
if (links.size() == 0 || total_size == 0) return kSuccess;
|
||||||
// number of links
|
// number of links
|
||||||
const int nlink = static_cast<int>(links.size());
|
const int nlink = static_cast<int>(links.size());
|
||||||
// size of space already read from data
|
// size of space already read from data
|
||||||
|
|||||||
@ -67,6 +67,7 @@ AllReduceRobust::MsgPassing(const NodeType &node_value,
|
|||||||
}
|
}
|
||||||
// select helper
|
// select helper
|
||||||
utils::SelectHelper selecter;
|
utils::SelectHelper selecter;
|
||||||
|
bool done = (stage == 3);
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
selecter.WatchException(links[i].sock);
|
selecter.WatchException(links[i].sock);
|
||||||
switch (stage) {
|
switch (stage) {
|
||||||
@ -80,12 +81,14 @@ AllReduceRobust::MsgPassing(const NodeType &node_value,
|
|||||||
case 3:
|
case 3:
|
||||||
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
|
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
|
||||||
selecter.WatchWrite(links[i].sock);
|
selecter.WatchWrite(links[i].sock);
|
||||||
|
done = false;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
default: utils::Error("invalid stage");
|
default: utils::Error("invalid stage");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// select must return
|
// finish all the stages, and write out message
|
||||||
|
if (done) break;
|
||||||
selecter.Select();
|
selecter.Select();
|
||||||
// exception handling
|
// exception handling
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
@ -134,15 +137,11 @@ AllReduceRobust::MsgPassing(const NodeType &node_value,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (stage == 3) {
|
if (stage == 3) {
|
||||||
bool finished = true;
|
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
|
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
|
||||||
if (!links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType))) return kSockError;
|
if (!links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType))) return kSockError;
|
||||||
if (links[i].size_write != sizeof(EdgeType)) finished = false;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// finish all the stages
|
|
||||||
if (finished) break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
|
|||||||
@ -352,7 +352,7 @@ AllReduceRobust::TryRecoverData(RecoverType role,
|
|||||||
int recv_link,
|
int recv_link,
|
||||||
const std::vector<bool> &req_in) {
|
const std::vector<bool> &req_in) {
|
||||||
// no need to run recovery for zero size message
|
// no need to run recovery for zero size message
|
||||||
if (size == 0) return kSuccess;
|
if (links.size() == 0 || size == 0) return kSuccess;
|
||||||
utils::Assert(req_in.size() == links.size(), "TryRecoverData");
|
utils::Assert(req_in.size() == links.size(), "TryRecoverData");
|
||||||
const int nlink = static_cast<int>(links.size());
|
const int nlink = static_cast<int>(links.size());
|
||||||
{
|
{
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user