next round try more careful select design

This commit is contained in:
tqchen 2014-11-30 21:07:34 -08:00
parent ecb09a23bc
commit f7928c68a3
3 changed files with 7 additions and 8 deletions

View File

@ -151,7 +151,7 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
size_t type_nbytes, size_t type_nbytes,
size_t count, size_t count,
ReduceFunction reducer) { ReduceFunction reducer) {
if (links.size() == 0) return kSuccess; if (links.size() == 0 || count == 0) return kSuccess;
// total size of message // total size of message
const size_t total_size = type_nbytes * count; const size_t total_size = type_nbytes * count;
// number of links // number of links
@ -287,7 +287,7 @@ AllReduceBase::TryAllReduce(void *sendrecvbuf_,
*/ */
AllReduceBase::ReturnType AllReduceBase::ReturnType
AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
if (links.size() == 0) return kSuccess; if (links.size() == 0 || total_size == 0) return kSuccess;
// number of links // number of links
const int nlink = static_cast<int>(links.size()); const int nlink = static_cast<int>(links.size());
// size of space already read from data // size of space already read from data

View File

@ -67,6 +67,7 @@ AllReduceRobust::MsgPassing(const NodeType &node_value,
} }
// select helper // select helper
utils::SelectHelper selecter; utils::SelectHelper selecter;
bool done = (stage == 3);
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
selecter.WatchException(links[i].sock); selecter.WatchException(links[i].sock);
switch (stage) { switch (stage) {
@ -80,12 +81,14 @@ AllReduceRobust::MsgPassing(const NodeType &node_value,
case 3: case 3:
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) { if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
selecter.WatchWrite(links[i].sock); selecter.WatchWrite(links[i].sock);
done = false;
} }
break; break;
default: utils::Error("invalid stage"); default: utils::Error("invalid stage");
} }
} }
// select must return // finish all the stages, and write out message
if (done) break;
selecter.Select(); selecter.Select();
// exception handling // exception handling
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
@ -134,15 +137,11 @@ AllReduceRobust::MsgPassing(const NodeType &node_value,
} }
} }
if (stage == 3) { if (stage == 3) {
bool finished = true;
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) { if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
if (!links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType))) return kSockError; if (!links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType))) return kSockError;
if (links[i].size_write != sizeof(EdgeType)) finished = false;
} }
} }
// finish all the stages
if (finished) break;
} }
} }
return kSuccess; return kSuccess;

View File

@ -352,7 +352,7 @@ AllReduceRobust::TryRecoverData(RecoverType role,
int recv_link, int recv_link,
const std::vector<bool> &req_in) { const std::vector<bool> &req_in) {
// no need to run recovery for zero size message // no need to run recovery for zero size message
if (size == 0) return kSuccess; if (links.size() == 0 || size == 0) return kSuccess;
utils::Assert(req_in.size() == links.size(), "TryRecoverData"); utils::Assert(req_in.size() == links.size(), "TryRecoverData");
const int nlink = static_cast<int>(links.size()); const int nlink = static_cast<int>(links.size());
{ {