finish message passing, do a review on msg passing and decide

This commit is contained in:
tqchen 2014-11-30 17:40:30 -08:00
parent 38cd595235
commit d8d648549f
2 changed files with 118 additions and 8 deletions

View File

@ -36,6 +36,113 @@ AllReduceRobust::MsgPassing(const NodeType &node_value,
const std::vector<EdgeType> &edge_in,
size_t out_index)
) {
if (links.size() == 0) return kSuccess;
// number of links
const int nlink = static_cast<int>(links.size());
// initialize the pointers
for (int i = 0; i < nlink; ++i) {
links[i].ResetSize();
}
std::vector<EdgeType> &edge_in = *p_edge_in;
std::vector<EdgeType> &edge_out = *p_edge_out;
edge_in.resize(nlink);
edge_out.resize(nlink);
// stages in the process
// 0: recv messages from childs
// 1: send message to parent
// 2: recv message from parent
// 3: send message to childs
int stage = 0;
// if no childs, no need to reduce
if (nlink == static_cast<int>(parent_index != -1)) {
stage = 1;
}
// while we have not passed the messages out
while (true) {
// for node with no parent, directly do stage 3
if (parent_index == -1) {
utils::Assert(stage != 2 && stage != 1, "invalie stage id");
}
// select helper
utils::SelectHelper selecter;
for (int i = 0; i < nlink; ++i) {
selecter.WatchException(links[i].sock);
switch (stage) {
case 0:
if (i != parent_index && links[i].size_read != sizeof(EdgeType)) {
selecter.WatchRead(links[i].sock);
}
break;
case 1: if (i == parent_index) selecter.WatchWrite(links[i].sock); break;
case 2: if (i == parent_index) selecter.WatchRead(links[i].sock); break;
case 3:
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
selecter.WatchWrite(links[i].sock);
}
break;
default: utils::Error("invalid stage");
}
}
// select must return
selecter.Select();
// exception handling
for (int i = 0; i < nlink; ++i) {
// recive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
}
if (stage == 0) {
bool finished = true;
// read data from childs
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
if (selecter.CheckRead(links[i].sock)) {
if (!links[i].ReadToArray(&edge_in[i], sizeof(EdgeType))) return kSockError;
}
if (links[i].size_read != sizeof(EdgeType)) finished = false;
}
}
// if no parent, jump to stage 3, otherwise do stage 1
if (finished) {
if (parent_index != -1) {
edge_out[parent_index] = func(node_value, edge_in, parent_index);
stage = 1;
} else {
for (int i = 0; i < nlink; ++i) {
edge_out[i] = func(node_value, edge_in, i);
}
stage = 3;
}
}
}
if (stage == 1) {
const int pid = this->parent_index;
utils::Assert(pid != -1, "MsgPassing invalid stage");
if (!links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType))) return kSockError;
if (links[pid].size_write == sizeof(EdgeType)) stage = 2;
}
if (stage == 2) {
const int pid = this->parent_index;
utils::Assert(pid != -1, "MsgPassing invalid stage");
if (!links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType))) return kSockError;
if (links[pid].size_read == sizeof(EdgeType)) {
for (int i = 0; i < nlink; ++i) {
if (i != pid) edge_out[i] = func(node_value, edge_in, i);
}
stage = 3;
}
}
if (stage == 3) {
bool finished = true;
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
if (!links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType))) return kSockError;
if (links[i].size_write != sizeof(EdgeType)) finished = false;
}
}
// finish all the stages
if (finished) break;
}
}
return kSuccess;
}
} // namespace engine

View File

@ -252,8 +252,8 @@ ShortestDist(const std::pair<bool, size_t> &node_value,
* \param out_index the edge index of output link
* \return the request to the output edge
*/
inline bool DataRequest(const std::pair<bool, int> &node_value,
const std::vector<bool> &req_in,
inline char DataRequest(const std::pair<bool, int> &node_value,
const std::vector<char> &req_in,
size_t out_index) {
// whether current node need to request data
bool request_data = node_value.first;
@ -261,13 +261,13 @@ inline bool DataRequest(const std::pair<bool, int> &node_value,
// can be -1, which means current node contains data
const int best_link = node_value.second;
if (static_cast<int>(out_index) == best_link) {
if (request_data) return true;
if (request_data) return 1;
for (size_t i = 0; i < req_in.size(); ++i) {
if (i == out_index) continue;
if (req_in[i]) return true;
if (req_in[i] != 0) return 1;
}
}
return false;
return 0;
}
/*!
* \brief try to decide the recovery message passing request
@ -313,13 +313,16 @@ AllReduceRobust::TryDecideRequest(AllReduceRobust::RecoverType role,
}
}
// get the node request
std::vector<bool> &req_in = *p_req_in;
std::vector<bool> req_out;
std::vector<char> req_in, req_out;
ReturnType succ = MsgPassing(std::make_pair(role == kRequestData, best_link),
&req_in, &req_out, DataRequest);
if (succ != kSuccess) return succ;
bool need_recv = false;
// set p_req_in
p_req_in->resize(req_in.size());
for (size_t i = 0; i < req_in.size(); ++i) {
// set p_req_in
(*p_req_in)[i] = (req_in[i] != 0);
if (req_out[i]) {
utils::Assert(!req_in[i], "cannot get and receive request");
utils::Assert(static_cast<int>(i) == best_link, "request result inconsistent");
@ -331,7 +334,7 @@ AllReduceRobust::TryDecideRequest(AllReduceRobust::RecoverType role,
utils::Assert(!req_in[i], "Bug in TryDecideRequest");
}
*p_req_outlink = 2;
} else {
} else {
*p_req_outlink = best_link;
}
return kSuccess;