checkin decide request, todo message passing
This commit is contained in:
parent
68f13cd739
commit
7a60cb7f3e
@ -6,6 +6,8 @@
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#define NOMINMAX
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
#include "./utils.h"
|
||||
#include "./engine_robust.h"
|
||||
|
||||
@ -213,6 +215,127 @@ bool AllReduceRobust::CheckAndRecover(ReturnType err_type) {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
/*!
|
||||
* \brief message passing function, used to decide the
|
||||
* shortest distance to the possible source of data
|
||||
* \param node_value a pair of have_data and size
|
||||
* have_data whether current node have data
|
||||
* size gives the size of data, if current node is kHaveData
|
||||
* \param dist_in the shorest to any data source distance in each direction
|
||||
* \param out_index the edge index of output link
|
||||
* \return the shorest distance result of out edge specified by out_index
|
||||
*/
|
||||
inline std::pair<int,size_t>
|
||||
ShortestDist(const std::pair<bool, size_t> &node_value,
|
||||
const std::vector< std::pair<int, size_t> > &dist_in,
|
||||
size_t out_index) {
|
||||
if (node_value.first) {
|
||||
return std::make_pair(1, node_value.second);
|
||||
}
|
||||
size_t size = 0;
|
||||
int res = std::numeric_limits<int>::max();
|
||||
for (size_t i = 0; i < dist_in.size(); ++i) {
|
||||
if (i == out_index) continue;
|
||||
if (dist_in[i].first < res) {
|
||||
res = dist_in[i].first; size = dist_in[i].second;
|
||||
}
|
||||
}
|
||||
return std::make_pair(res, size);
|
||||
}
|
||||
/*!
|
||||
* \brief message passing function, used to decide the
|
||||
* data request from each edge, whether need to request data from certain edge
|
||||
* \param node_value a pair of request_data and best_link
|
||||
* request_data stores whether current node need to request data
|
||||
* best_link gives the best edge index to fetch the data
|
||||
* \param req_in the data request from incoming edges
|
||||
* \param out_index the edge index of output link
|
||||
* \return the request to the output edge
|
||||
*/
|
||||
inline bool DataRequest(const std::pair<bool, int> &node_value,
|
||||
const std::vector<bool> &req_in,
|
||||
size_t out_index) {
|
||||
// whether current node need to request data
|
||||
bool request_data = node_value.first;
|
||||
// which edge index is the best link to request data
|
||||
// can be -1, which means current node contains data
|
||||
const int best_link = node_value.second;
|
||||
if (static_cast<int>(out_index) == best_link) {
|
||||
if (request_data) return true;
|
||||
for (size_t i = 0; i < req_in.size(); ++i) {
|
||||
if (i == out_index) continue;
|
||||
if (req_in[i]) return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
/*!
|
||||
* \brief try to decide the recovery message passing request
|
||||
* \param role the current role of the node
|
||||
* \param p_req_outlink used to store the output link the
|
||||
* current node should recv data from,
|
||||
* this can be -1 or -2,
|
||||
* -1 means current node have the data
|
||||
* -2 means current node do not have data, but also do not need to send/recv data
|
||||
* \param p_req_in used to store the resulting vector, indicating which link we should send the data to
|
||||
* \param p_size used to store the size of the message, for node in state kHaveData,
|
||||
* this size must be set correctly before calling the function
|
||||
* for others, this surves as output parameter
|
||||
*
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
AllReduceRobust::ReturnType
|
||||
AllReduceRobust::TryDecideRequest(AllReduceRobust::RecoverType role,
|
||||
int *p_req_outlink,
|
||||
std::vector<bool> *p_req_in,
|
||||
size_t *p_size) {
|
||||
int best_link = -2;
|
||||
{// get the shortest distance to the request point
|
||||
std::vector< std::pair<int,size_t> > dist_in, dist_out;
|
||||
ReturnType succ = MsgPassing(std::make_pair(role == kHaveData, *p_size),
|
||||
&dist_in, &dist_out, ShortestDist);
|
||||
if (succ != kSuccess) return succ;
|
||||
if (role != kHaveData) {
|
||||
for (size_t i = 0; i < dist_in.size(); ++i) {
|
||||
if (dist_in[i].first != std::numeric_limits<int>::max()) {
|
||||
utils::Check(best_link == -2 || *p_size == dist_in[i].second,
|
||||
"AllReduce size inconsistent");
|
||||
if (best_link == -2 || dist_in[i].first < dist_in[best_link].first) {
|
||||
best_link = static_cast<int>(i);
|
||||
*p_size = dist_in[i].second;
|
||||
}
|
||||
}
|
||||
}
|
||||
utils::Check(best_link != -2, "Too many nodes went down and we cannot recover..");
|
||||
} else {
|
||||
best_link = -1;
|
||||
}
|
||||
}
|
||||
// get the node request
|
||||
std::vector<bool> &req_in = *p_req_in;
|
||||
std::vector<bool> req_out;
|
||||
ReturnType succ = MsgPassing(std::make_pair(role == kRequestData, best_link),
|
||||
&req_in, &req_out, DataRequest);
|
||||
if (succ != kSuccess) return succ;
|
||||
bool need_recv = false;
|
||||
for (size_t i = 0; i < req_in.size(); ++i) {
|
||||
if (req_out[i]) {
|
||||
utils::Assert(!req_in[i], "cannot get and receive request");
|
||||
utils::Assert(static_cast<int>(i) == best_link, "request result inconsistent");
|
||||
need_recv = true;
|
||||
}
|
||||
}
|
||||
if (role == kPassData && !need_recv) {
|
||||
for (size_t i = 0; i < req_in.size(); ++i) {
|
||||
utils::Assert(!req_in[i], "Bug in TryDecideRequest");
|
||||
}
|
||||
*p_req_outlink = 2;
|
||||
} else {
|
||||
*p_req_outlink = best_link;
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
* \brief try to load check point
|
||||
*
|
||||
@ -225,7 +348,7 @@ bool AllReduceRobust::CheckAndRecover(ReturnType err_type) {
|
||||
* \sa ReturnType
|
||||
*/
|
||||
AllReduceRobust::ReturnType AllReduceRobust::TryLoadCheckPoint(bool requester) {
|
||||
utils::Error("TryLoadCheckPoint: not implemented");
|
||||
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
@ -308,10 +431,12 @@ bool AllReduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
|
||||
} else {
|
||||
// no check point
|
||||
if (act.load_check()) {
|
||||
// all the nodes called load_check, this is an incomplete action
|
||||
if (!act.diff_seq()) return false;
|
||||
// load check have higher priority, do load_check
|
||||
if (!CheckAndRecover(TryLoadCheckPoint(req.load_check()))) continue;
|
||||
// if requested load check, then misson complete
|
||||
if (req.load_check()) return true;
|
||||
if (req.load_check()) return true;
|
||||
} else {
|
||||
// no special flags, no checkpoint, check ack, load_check
|
||||
utils::Assert(act.min_seqno() != ActionSummary::kMaxSeq, "min seq bug");
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
*/
|
||||
#ifndef ALLREDUCE_ENGINE_ROBUST_H
|
||||
#define ALLREDUCE_ENGINE_ROBUST_H
|
||||
#include <vector>
|
||||
#include "./engine.h"
|
||||
#include "./engine_base.h"
|
||||
|
||||
@ -57,6 +58,15 @@ class AllReduceRobust : public AllReduceBase {
|
||||
const static char kResetMark = 97;
|
||||
// and mark for channel cleanup
|
||||
const static char kResetAck = 97;
|
||||
/*! \brief type of roles each node can play during recovery */
|
||||
enum RecoverType {
|
||||
/*! \brief current node have data */
|
||||
kHaveData,
|
||||
/*! \brief current node request data */
|
||||
kRequestData,
|
||||
/*! \brief current node only helps to pass data around */
|
||||
kPassData
|
||||
};
|
||||
/*!
|
||||
* \brief summary of actions proposed in all nodes
|
||||
* this data structure is used to make consensus decision
|
||||
@ -246,6 +256,53 @@ class AllReduceRobust : public AllReduceBase {
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryGetResult(void *buf, size_t size, int seqno, bool requester);
|
||||
/*!
|
||||
* \brief try to decide the recovery message passing request
|
||||
* \param role the current role of the node
|
||||
* \param p_req_outlink used to store the output link the
|
||||
* current node should recv data from,
|
||||
* this can be nonnegative value, -1 or -2,
|
||||
* -1 means current node have the data
|
||||
* -2 means current node do not have data, but also do not need to send/recv data
|
||||
* \param p_req_in used to store the resulting vector, indicating which link we should send the data to
|
||||
* \param p_size used to store the size of the message, for node in state kHaveData,
|
||||
* this size must be set correctly before calling the function
|
||||
* for others, this surves as output parameter
|
||||
*
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryDecideRequest(RecoverType role,
|
||||
int *p_req_outlink,
|
||||
std::vector<bool> *p_req_in,
|
||||
size_t *p_size);
|
||||
/*!
|
||||
* \brief run message passing algorithm on the allreduce tree
|
||||
* the result is edge message stored in p_edge_in and p_edge_out
|
||||
* \param node_value the value associated with current node
|
||||
* \param p_edge_in used to store input message from each of the edge
|
||||
* \param p_edge_out used to store output message from each of the edge
|
||||
* \param func a function that defines the message passing rule
|
||||
* Parameters of func:
|
||||
* - node_value same as node_value in the main function
|
||||
* - edge_in the array of input messages from each edge,
|
||||
* this includes the output edge, which should be excluded
|
||||
* - out_index array the index of output edge, the function should
|
||||
* exclude the output edge when compute the message passing value
|
||||
* Return of func:
|
||||
* the function returns the output message based on the input message and node_value
|
||||
*
|
||||
* \tparam EdgeType type of edge message, must be simple struct
|
||||
* \tparam NodeType type of node value
|
||||
*/
|
||||
template<typename NodeType, typename EdgeType>
|
||||
inline ReturnType MsgPassing(const NodeType &node_value,
|
||||
std::vector<EdgeType> *p_edge_in,
|
||||
std::vector<EdgeType> *p_edge_out,
|
||||
EdgeType (*func) (const NodeType &node_value,
|
||||
const std::vector<EdgeType> &edge_in,
|
||||
size_t out_index)
|
||||
);
|
||||
//---- recovery data structure ----
|
||||
// call sequence counter, records how many calls we made so far
|
||||
// from last call to CheckPoint, LoadCheckPoint
|
||||
@ -254,4 +311,7 @@ class AllReduceRobust : public AllReduceBase {
|
||||
ResultBuffer resbuf;
|
||||
};
|
||||
} // namespace engine
|
||||
// implementation of inline template function
|
||||
#include "./engine_robust-inl.h"
|
||||
|
||||
#endif // ALLREDUCE_ENGINE_ROBUST_H
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user