finish message passing, do a review on msg passing and decide
This commit is contained in:
parent
38cd595235
commit
d8d648549f
@ -36,6 +36,113 @@ AllReduceRobust::MsgPassing(const NodeType &node_value,
|
||||
const std::vector<EdgeType> &edge_in,
|
||||
size_t out_index)
|
||||
) {
|
||||
if (links.size() == 0) return kSuccess;
|
||||
// number of links
|
||||
const int nlink = static_cast<int>(links.size());
|
||||
// initialize the pointers
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
links[i].ResetSize();
|
||||
}
|
||||
std::vector<EdgeType> &edge_in = *p_edge_in;
|
||||
std::vector<EdgeType> &edge_out = *p_edge_out;
|
||||
edge_in.resize(nlink);
|
||||
edge_out.resize(nlink);
|
||||
// stages in the process
|
||||
// 0: recv messages from childs
|
||||
// 1: send message to parent
|
||||
// 2: recv message from parent
|
||||
// 3: send message to childs
|
||||
int stage = 0;
|
||||
// if no childs, no need to reduce
|
||||
if (nlink == static_cast<int>(parent_index != -1)) {
|
||||
stage = 1;
|
||||
}
|
||||
// while we have not passed the messages out
|
||||
while (true) {
|
||||
// for node with no parent, directly do stage 3
|
||||
if (parent_index == -1) {
|
||||
utils::Assert(stage != 2 && stage != 1, "invalie stage id");
|
||||
}
|
||||
// select helper
|
||||
utils::SelectHelper selecter;
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
selecter.WatchException(links[i].sock);
|
||||
switch (stage) {
|
||||
case 0:
|
||||
if (i != parent_index && links[i].size_read != sizeof(EdgeType)) {
|
||||
selecter.WatchRead(links[i].sock);
|
||||
}
|
||||
break;
|
||||
case 1: if (i == parent_index) selecter.WatchWrite(links[i].sock); break;
|
||||
case 2: if (i == parent_index) selecter.WatchRead(links[i].sock); break;
|
||||
case 3:
|
||||
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
|
||||
selecter.WatchWrite(links[i].sock);
|
||||
}
|
||||
break;
|
||||
default: utils::Error("invalid stage");
|
||||
}
|
||||
}
|
||||
// select must return
|
||||
selecter.Select();
|
||||
// exception handling
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
// recive OOB message from some link
|
||||
if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
|
||||
}
|
||||
if (stage == 0) {
|
||||
bool finished = true;
|
||||
// read data from childs
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index) {
|
||||
if (selecter.CheckRead(links[i].sock)) {
|
||||
if (!links[i].ReadToArray(&edge_in[i], sizeof(EdgeType))) return kSockError;
|
||||
}
|
||||
if (links[i].size_read != sizeof(EdgeType)) finished = false;
|
||||
}
|
||||
}
|
||||
// if no parent, jump to stage 3, otherwise do stage 1
|
||||
if (finished) {
|
||||
if (parent_index != -1) {
|
||||
edge_out[parent_index] = func(node_value, edge_in, parent_index);
|
||||
stage = 1;
|
||||
} else {
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
edge_out[i] = func(node_value, edge_in, i);
|
||||
}
|
||||
stage = 3;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (stage == 1) {
|
||||
const int pid = this->parent_index;
|
||||
utils::Assert(pid != -1, "MsgPassing invalid stage");
|
||||
if (!links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType))) return kSockError;
|
||||
if (links[pid].size_write == sizeof(EdgeType)) stage = 2;
|
||||
}
|
||||
if (stage == 2) {
|
||||
const int pid = this->parent_index;
|
||||
utils::Assert(pid != -1, "MsgPassing invalid stage");
|
||||
if (!links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType))) return kSockError;
|
||||
if (links[pid].size_read == sizeof(EdgeType)) {
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != pid) edge_out[i] = func(node_value, edge_in, i);
|
||||
}
|
||||
stage = 3;
|
||||
}
|
||||
}
|
||||
if (stage == 3) {
|
||||
bool finished = true;
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
|
||||
if (!links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType))) return kSockError;
|
||||
if (links[i].size_write != sizeof(EdgeType)) finished = false;
|
||||
}
|
||||
}
|
||||
// finish all the stages
|
||||
if (finished) break;
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
} // namespace engine
|
||||
|
||||
@ -252,8 +252,8 @@ ShortestDist(const std::pair<bool, size_t> &node_value,
|
||||
* \param out_index the edge index of output link
|
||||
* \return the request to the output edge
|
||||
*/
|
||||
inline bool DataRequest(const std::pair<bool, int> &node_value,
|
||||
const std::vector<bool> &req_in,
|
||||
inline char DataRequest(const std::pair<bool, int> &node_value,
|
||||
const std::vector<char> &req_in,
|
||||
size_t out_index) {
|
||||
// whether current node need to request data
|
||||
bool request_data = node_value.first;
|
||||
@ -261,13 +261,13 @@ inline bool DataRequest(const std::pair<bool, int> &node_value,
|
||||
// can be -1, which means current node contains data
|
||||
const int best_link = node_value.second;
|
||||
if (static_cast<int>(out_index) == best_link) {
|
||||
if (request_data) return true;
|
||||
if (request_data) return 1;
|
||||
for (size_t i = 0; i < req_in.size(); ++i) {
|
||||
if (i == out_index) continue;
|
||||
if (req_in[i]) return true;
|
||||
if (req_in[i] != 0) return 1;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return 0;
|
||||
}
|
||||
/*!
|
||||
* \brief try to decide the recovery message passing request
|
||||
@ -313,13 +313,16 @@ AllReduceRobust::TryDecideRequest(AllReduceRobust::RecoverType role,
|
||||
}
|
||||
}
|
||||
// get the node request
|
||||
std::vector<bool> &req_in = *p_req_in;
|
||||
std::vector<bool> req_out;
|
||||
std::vector<char> req_in, req_out;
|
||||
ReturnType succ = MsgPassing(std::make_pair(role == kRequestData, best_link),
|
||||
&req_in, &req_out, DataRequest);
|
||||
if (succ != kSuccess) return succ;
|
||||
bool need_recv = false;
|
||||
// set p_req_in
|
||||
p_req_in->resize(req_in.size());
|
||||
for (size_t i = 0; i < req_in.size(); ++i) {
|
||||
// set p_req_in
|
||||
(*p_req_in)[i] = (req_in[i] != 0);
|
||||
if (req_out[i]) {
|
||||
utils::Assert(!req_in[i], "cannot get and receive request");
|
||||
utils::Assert(static_cast<int>(i) == best_link, "request result inconsistent");
|
||||
@ -331,7 +334,7 @@ AllReduceRobust::TryDecideRequest(AllReduceRobust::RecoverType role,
|
||||
utils::Assert(!req_in[i], "Bug in TryDecideRequest");
|
||||
}
|
||||
*p_req_outlink = 2;
|
||||
} else {
|
||||
} else {
|
||||
*p_req_outlink = best_link;
|
||||
}
|
||||
return kSuccess;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user