checkin decide request, todo message passing

This commit is contained in:
tqchen 2014-11-30 16:37:26 -08:00
parent 68f13cd739
commit 7a60cb7f3e
2 changed files with 187 additions and 2 deletions

View File

@ -6,6 +6,8 @@
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include <limits>
#include <utility>
#include "./utils.h"
#include "./engine_robust.h"
@ -213,6 +215,127 @@ bool AllReduceRobust::CheckAndRecover(ReturnType err_type) {
}
return false;
}
/*!
* \brief message passing function, used to decide the
* shortest distance to the possible source of data
* \param node_value a pair of have_data and size
* have_data whether current node have data
* size gives the size of data, if current node is kHaveData
* \param dist_in the shorest to any data source distance in each direction
* \param out_index the edge index of output link
* \return the shorest distance result of out edge specified by out_index
*/
inline std::pair<int,size_t>
ShortestDist(const std::pair<bool, size_t> &node_value,
const std::vector< std::pair<int, size_t> > &dist_in,
size_t out_index) {
if (node_value.first) {
return std::make_pair(1, node_value.second);
}
size_t size = 0;
int res = std::numeric_limits<int>::max();
for (size_t i = 0; i < dist_in.size(); ++i) {
if (i == out_index) continue;
if (dist_in[i].first < res) {
res = dist_in[i].first; size = dist_in[i].second;
}
}
return std::make_pair(res, size);
}
/*!
* \brief message passing function, used to decide the
* data request from each edge, whether need to request data from certain edge
* \param node_value a pair of request_data and best_link
* request_data stores whether current node need to request data
* best_link gives the best edge index to fetch the data
* \param req_in the data request from incoming edges
* \param out_index the edge index of output link
* \return the request to the output edge
*/
inline bool DataRequest(const std::pair<bool, int> &node_value,
const std::vector<bool> &req_in,
size_t out_index) {
// whether current node need to request data
bool request_data = node_value.first;
// which edge index is the best link to request data
// can be -1, which means current node contains data
const int best_link = node_value.second;
if (static_cast<int>(out_index) == best_link) {
if (request_data) return true;
for (size_t i = 0; i < req_in.size(); ++i) {
if (i == out_index) continue;
if (req_in[i]) return true;
}
}
return false;
}
/*!
* \brief try to decide the recovery message passing request
* \param role the current role of the node
* \param p_req_outlink used to store the output link the
* current node should recv data from,
* this can be -1 or -2,
* -1 means current node have the data
* -2 means current node do not have data, but also do not need to send/recv data
* \param p_req_in used to store the resulting vector, indicating which link we should send the data to
* \param p_size used to store the size of the message, for node in state kHaveData,
* this size must be set correctly before calling the function
* for others, this surves as output parameter
*
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllReduceRobust::ReturnType
AllReduceRobust::TryDecideRequest(AllReduceRobust::RecoverType role,
int *p_req_outlink,
std::vector<bool> *p_req_in,
size_t *p_size) {
int best_link = -2;
{// get the shortest distance to the request point
std::vector< std::pair<int,size_t> > dist_in, dist_out;
ReturnType succ = MsgPassing(std::make_pair(role == kHaveData, *p_size),
&dist_in, &dist_out, ShortestDist);
if (succ != kSuccess) return succ;
if (role != kHaveData) {
for (size_t i = 0; i < dist_in.size(); ++i) {
if (dist_in[i].first != std::numeric_limits<int>::max()) {
utils::Check(best_link == -2 || *p_size == dist_in[i].second,
"AllReduce size inconsistent");
if (best_link == -2 || dist_in[i].first < dist_in[best_link].first) {
best_link = static_cast<int>(i);
*p_size = dist_in[i].second;
}
}
}
utils::Check(best_link != -2, "Too many nodes went down and we cannot recover..");
} else {
best_link = -1;
}
}
// get the node request
std::vector<bool> &req_in = *p_req_in;
std::vector<bool> req_out;
ReturnType succ = MsgPassing(std::make_pair(role == kRequestData, best_link),
&req_in, &req_out, DataRequest);
if (succ != kSuccess) return succ;
bool need_recv = false;
for (size_t i = 0; i < req_in.size(); ++i) {
if (req_out[i]) {
utils::Assert(!req_in[i], "cannot get and receive request");
utils::Assert(static_cast<int>(i) == best_link, "request result inconsistent");
need_recv = true;
}
}
if (role == kPassData && !need_recv) {
for (size_t i = 0; i < req_in.size(); ++i) {
utils::Assert(!req_in[i], "Bug in TryDecideRequest");
}
*p_req_outlink = 2;
} else {
*p_req_outlink = best_link;
}
return kSuccess;
}
/*!
* \brief try to load check point
*
@ -225,7 +348,7 @@ bool AllReduceRobust::CheckAndRecover(ReturnType err_type) {
* \sa ReturnType
*/
AllReduceRobust::ReturnType AllReduceRobust::TryLoadCheckPoint(bool requester) {
utils::Error("TryLoadCheckPoint: not implemented");
return kSuccess;
}
/*!
@ -308,10 +431,12 @@ bool AllReduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
} else {
// no check point
if (act.load_check()) {
// all the nodes called load_check, this is an incomplete action
if (!act.diff_seq()) return false;
// load check have higher priority, do load_check
if (!CheckAndRecover(TryLoadCheckPoint(req.load_check()))) continue;
// if requested load check, then misson complete
if (req.load_check()) return true;
if (req.load_check()) return true;
} else {
// no special flags, no checkpoint, check ack, load_check
utils::Assert(act.min_seqno() != ActionSummary::kMaxSeq, "min seq bug");

View File

@ -9,6 +9,7 @@
*/
#ifndef ALLREDUCE_ENGINE_ROBUST_H
#define ALLREDUCE_ENGINE_ROBUST_H
#include <vector>
#include "./engine.h"
#include "./engine_base.h"
@ -57,6 +58,15 @@ class AllReduceRobust : public AllReduceBase {
const static char kResetMark = 97;
// and mark for channel cleanup
const static char kResetAck = 97;
/*! \brief type of roles each node can play during recovery */
enum RecoverType {
/*! \brief current node have data */
kHaveData,
/*! \brief current node request data */
kRequestData,
/*! \brief current node only helps to pass data around */
kPassData
};
/*!
* \brief summary of actions proposed in all nodes
* this data structure is used to make consensus decision
@ -246,6 +256,53 @@ class AllReduceRobust : public AllReduceBase {
* \sa ReturnType
*/
ReturnType TryGetResult(void *buf, size_t size, int seqno, bool requester);
/*!
* \brief try to decide the recovery message passing request
* \param role the current role of the node
* \param p_req_outlink used to store the output link the
* current node should recv data from,
* this can be nonnegative value, -1 or -2,
* -1 means current node have the data
* -2 means current node do not have data, but also do not need to send/recv data
* \param p_req_in used to store the resulting vector, indicating which link we should send the data to
* \param p_size used to store the size of the message, for node in state kHaveData,
* this size must be set correctly before calling the function
* for others, this surves as output parameter
*
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryDecideRequest(RecoverType role,
int *p_req_outlink,
std::vector<bool> *p_req_in,
size_t *p_size);
/*!
* \brief run message passing algorithm on the allreduce tree
* the result is edge message stored in p_edge_in and p_edge_out
* \param node_value the value associated with current node
* \param p_edge_in used to store input message from each of the edge
* \param p_edge_out used to store output message from each of the edge
* \param func a function that defines the message passing rule
* Parameters of func:
* - node_value same as node_value in the main function
* - edge_in the array of input messages from each edge,
* this includes the output edge, which should be excluded
* - out_index array the index of output edge, the function should
* exclude the output edge when compute the message passing value
* Return of func:
* the function returns the output message based on the input message and node_value
*
* \tparam EdgeType type of edge message, must be simple struct
* \tparam NodeType type of node value
*/
template<typename NodeType, typename EdgeType>
inline ReturnType MsgPassing(const NodeType &node_value,
std::vector<EdgeType> *p_edge_in,
std::vector<EdgeType> *p_edge_out,
EdgeType (*func) (const NodeType &node_value,
const std::vector<EdgeType> &edge_in,
size_t out_index)
);
//---- recovery data structure ----
// call sequence counter, records how many calls we made so far
// from last call to CheckPoint, LoadCheckPoint
@ -254,4 +311,7 @@ class AllReduceRobust : public AllReduceBase {
ResultBuffer resbuf;
};
} // namespace engine
// implementation of inline template function
#include "./engine_robust-inl.h"
#endif // ALLREDUCE_ENGINE_ROBUST_H