checkin decide request, todo message passing
This commit is contained in:
parent
68f13cd739
commit
7a60cb7f3e
@ -6,6 +6,8 @@
|
|||||||
#define _CRT_SECURE_NO_WARNINGS
|
#define _CRT_SECURE_NO_WARNINGS
|
||||||
#define _CRT_SECURE_NO_DEPRECATE
|
#define _CRT_SECURE_NO_DEPRECATE
|
||||||
#define NOMINMAX
|
#define NOMINMAX
|
||||||
|
#include <limits>
|
||||||
|
#include <utility>
|
||||||
#include "./utils.h"
|
#include "./utils.h"
|
||||||
#include "./engine_robust.h"
|
#include "./engine_robust.h"
|
||||||
|
|
||||||
@ -213,6 +215,127 @@ bool AllReduceRobust::CheckAndRecover(ReturnType err_type) {
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
/*!
 * \brief message passing function, used to decide the
 *    shortest distance to the possible source of data
 * \param node_value a pair of have_data and size
 *           have_data whether current node have data
 *           size gives the size of data, if current node is kHaveData
 * \param dist_in the shortest distance to any data source reported by each edge,
 *           std::numeric_limits<int>::max() marks an edge that reaches no source
 * \param out_index the edge index of output link
 * \return the shortest distance result of out edge specified by out_index:
 *         (1, size) when the current node has the data, otherwise one hop more
 *         than the closest source seen on the other edges, or
 *         (std::numeric_limits<int>::max(), 0) when no other edge reaches a source
 */
inline std::pair<int, size_t>
ShortestDist(const std::pair<bool, size_t> &node_value,
             const std::vector< std::pair<int, size_t> > &dist_in,
             size_t out_index) {
  const int kUnreachable = std::numeric_limits<int>::max();
  if (node_value.first) {
    return std::make_pair(1, node_value.second);
  }
  size_t size = 0;
  int res = kUnreachable;
  for (size_t i = 0; i < dist_in.size(); ++i) {
    // the message sent to an edge must not depend on what that edge told us
    if (i == out_index) continue;
    // skip unreachable edges explicitly, adding a hop to the sentinel would overflow
    if (dist_in[i].first == kUnreachable) continue;
    // relaying through this node costs one extra hop
    const int hops = dist_in[i].first + 1;
    if (hops < res) {
      res = hops;
      size = dist_in[i].second;
    }
  }
  return std::make_pair(res, size);
}
|
||||||
|
/*!
 * \brief message passing rule that decides whether a data request
 *    has to be sent out on a given edge
 * \param node_value a pair of request_data and best_link
 *           request_data tells whether the current node itself needs the data
 *           best_link is the edge index the data should be fetched from,
 *           it can be -1, which means the current node contains the data
 * \param req_in the data requests received on the incoming edges
 * \param out_index the edge index of the output link
 * \return true when a request must be sent on the output edge
 */
inline bool DataRequest(const std::pair<bool, int> &node_value,
                        const std::vector<bool> &req_in,
                        size_t out_index) {
  // a request is only ever sent along the best link
  if (node_value.second != static_cast<int>(out_index)) return false;
  // the current node wants the data for itself
  if (node_value.first) return true;
  // otherwise forward a request that arrived on any of the other edges
  for (size_t edge = 0; edge != req_in.size(); ++edge) {
    if (edge != out_index && req_in[edge]) return true;
  }
  return false;
}
|
||||||
|
/*!
|
||||||
|
* \brief try to decide the recovery message passing request
|
||||||
|
* \param role the current role of the node
|
||||||
|
* \param p_req_outlink used to store the output link the
|
||||||
|
* current node should recv data from,
|
||||||
|
* this can be -1 or -2,
|
||||||
|
* -1 means current node have the data
|
||||||
|
* -2 means current node do not have data, but also do not need to send/recv data
|
||||||
|
* \param p_req_in used to store the resulting vector, indicating which link we should send the data to
|
||||||
|
* \param p_size used to store the size of the message, for node in state kHaveData,
|
||||||
|
* this size must be set correctly before calling the function
|
||||||
|
* for others, this surves as output parameter
|
||||||
|
*
|
||||||
|
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||||
|
* \sa ReturnType
|
||||||
|
*/
|
||||||
|
AllReduceRobust::ReturnType
|
||||||
|
AllReduceRobust::TryDecideRequest(AllReduceRobust::RecoverType role,
|
||||||
|
int *p_req_outlink,
|
||||||
|
std::vector<bool> *p_req_in,
|
||||||
|
size_t *p_size) {
|
||||||
|
int best_link = -2;
|
||||||
|
{// get the shortest distance to the request point
|
||||||
|
std::vector< std::pair<int,size_t> > dist_in, dist_out;
|
||||||
|
ReturnType succ = MsgPassing(std::make_pair(role == kHaveData, *p_size),
|
||||||
|
&dist_in, &dist_out, ShortestDist);
|
||||||
|
if (succ != kSuccess) return succ;
|
||||||
|
if (role != kHaveData) {
|
||||||
|
for (size_t i = 0; i < dist_in.size(); ++i) {
|
||||||
|
if (dist_in[i].first != std::numeric_limits<int>::max()) {
|
||||||
|
utils::Check(best_link == -2 || *p_size == dist_in[i].second,
|
||||||
|
"AllReduce size inconsistent");
|
||||||
|
if (best_link == -2 || dist_in[i].first < dist_in[best_link].first) {
|
||||||
|
best_link = static_cast<int>(i);
|
||||||
|
*p_size = dist_in[i].second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
utils::Check(best_link != -2, "Too many nodes went down and we cannot recover..");
|
||||||
|
} else {
|
||||||
|
best_link = -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// get the node request
|
||||||
|
std::vector<bool> &req_in = *p_req_in;
|
||||||
|
std::vector<bool> req_out;
|
||||||
|
ReturnType succ = MsgPassing(std::make_pair(role == kRequestData, best_link),
|
||||||
|
&req_in, &req_out, DataRequest);
|
||||||
|
if (succ != kSuccess) return succ;
|
||||||
|
bool need_recv = false;
|
||||||
|
for (size_t i = 0; i < req_in.size(); ++i) {
|
||||||
|
if (req_out[i]) {
|
||||||
|
utils::Assert(!req_in[i], "cannot get and receive request");
|
||||||
|
utils::Assert(static_cast<int>(i) == best_link, "request result inconsistent");
|
||||||
|
need_recv = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (role == kPassData && !need_recv) {
|
||||||
|
for (size_t i = 0; i < req_in.size(); ++i) {
|
||||||
|
utils::Assert(!req_in[i], "Bug in TryDecideRequest");
|
||||||
|
}
|
||||||
|
*p_req_outlink = 2;
|
||||||
|
} else {
|
||||||
|
*p_req_outlink = best_link;
|
||||||
|
}
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief try to load check point
|
* \brief try to load check point
|
||||||
*
|
*
|
||||||
@ -225,7 +348,7 @@ bool AllReduceRobust::CheckAndRecover(ReturnType err_type) {
|
|||||||
* \sa ReturnType
|
* \sa ReturnType
|
||||||
*/
|
*/
|
||||||
AllReduceRobust::ReturnType AllReduceRobust::TryLoadCheckPoint(bool requester) {
|
AllReduceRobust::ReturnType AllReduceRobust::TryLoadCheckPoint(bool requester) {
|
||||||
utils::Error("TryLoadCheckPoint: not implemented");
|
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -308,6 +431,8 @@ bool AllReduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
|
|||||||
} else {
|
} else {
|
||||||
// no check point
|
// no check point
|
||||||
if (act.load_check()) {
|
if (act.load_check()) {
|
||||||
|
// all the nodes called load_check, this is an incomplete action
|
||||||
|
if (!act.diff_seq()) return false;
|
||||||
// load check have higher priority, do load_check
|
// load check have higher priority, do load_check
|
||||||
if (!CheckAndRecover(TryLoadCheckPoint(req.load_check()))) continue;
|
if (!CheckAndRecover(TryLoadCheckPoint(req.load_check()))) continue;
|
||||||
// if requested load check, then mission complete
|
// if requested load check, then mission complete
|
||||||
|
|||||||
@ -9,6 +9,7 @@
|
|||||||
*/
|
*/
|
||||||
#ifndef ALLREDUCE_ENGINE_ROBUST_H
|
#ifndef ALLREDUCE_ENGINE_ROBUST_H
|
||||||
#define ALLREDUCE_ENGINE_ROBUST_H
|
#define ALLREDUCE_ENGINE_ROBUST_H
|
||||||
|
#include <vector>
|
||||||
#include "./engine.h"
|
#include "./engine.h"
|
||||||
#include "./engine_base.h"
|
#include "./engine_base.h"
|
||||||
|
|
||||||
@ -57,6 +58,15 @@ class AllReduceRobust : public AllReduceBase {
|
|||||||
const static char kResetMark = 97;
|
const static char kResetMark = 97;
|
||||||
// and mark for channel cleanup
|
// and mark for channel cleanup
|
||||||
const static char kResetAck = 97;
|
const static char kResetAck = 97;
|
||||||
|
/*! \brief the roles a node can take on during recovery */
enum RecoverType {
  /*! \brief the current node has the data */
  kHaveData = 0,
  /*! \brief the current node requests the data */
  kRequestData = 1,
  /*! \brief the current node only helps to pass the data around */
  kPassData = 2
};
|
||||||
/*!
|
/*!
|
||||||
* \brief summary of actions proposed in all nodes
|
* \brief summary of actions proposed in all nodes
|
||||||
* this data structure is used to make consensus decision
|
* this data structure is used to make consensus decision
|
||||||
@ -246,6 +256,53 @@ class AllReduceRobust : public AllReduceBase {
|
|||||||
* \sa ReturnType
|
* \sa ReturnType
|
||||||
*/
|
*/
|
||||||
ReturnType TryGetResult(void *buf, size_t size, int seqno, bool requester);
|
ReturnType TryGetResult(void *buf, size_t size, int seqno, bool requester);
|
||||||
|
/*!
|
||||||
|
* \brief try to decide the recovery message passing request
|
||||||
|
* \param role the current role of the node
|
||||||
|
* \param p_req_outlink used to store the output link the
|
||||||
|
* current node should recv data from,
|
||||||
|
* this can be nonnegative value, -1 or -2,
|
||||||
|
* -1 means current node have the data
|
||||||
|
* -2 means current node do not have data, but also do not need to send/recv data
|
||||||
|
* \param p_req_in used to store the resulting vector, indicating which link we should send the data to
|
||||||
|
* \param p_size used to store the size of the message, for node in state kHaveData,
|
||||||
|
* this size must be set correctly before calling the function
|
||||||
|
* for others, this serves as an output parameter
|
||||||
|
*
|
||||||
|
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||||
|
* \sa ReturnType
|
||||||
|
*/
|
||||||
|
ReturnType TryDecideRequest(RecoverType role,
|
||||||
|
int *p_req_outlink,
|
||||||
|
std::vector<bool> *p_req_in,
|
||||||
|
size_t *p_size);
|
||||||
|
/*!
|
||||||
|
* \brief run message passing algorithm on the allreduce tree
|
||||||
|
* the result is edge message stored in p_edge_in and p_edge_out
|
||||||
|
* \param node_value the value associated with current node
|
||||||
|
* \param p_edge_in used to store the input message from each of the edges
|
||||||
|
* \param p_edge_out used to store the output message for each of the edges
|
||||||
|
* \param func a function that defines the message passing rule
|
||||||
|
* Parameters of func:
|
||||||
|
* - node_value same as node_value in the main function
|
||||||
|
* - edge_in the array of input messages from each edge,
|
||||||
|
* this includes the output edge, which should be excluded
|
||||||
|
* - out_index the index of the output edge, the function should
|
||||||
|
* exclude the output edge when computing the message passing value
|
||||||
|
* Return of func:
|
||||||
|
* the function returns the output message based on the input message and node_value
|
||||||
|
*
|
||||||
|
* \tparam EdgeType type of edge message, must be simple struct
|
||||||
|
* \tparam NodeType type of node value
|
||||||
|
*/
|
||||||
|
template<typename NodeType, typename EdgeType>
|
||||||
|
inline ReturnType MsgPassing(const NodeType &node_value,
|
||||||
|
std::vector<EdgeType> *p_edge_in,
|
||||||
|
std::vector<EdgeType> *p_edge_out,
|
||||||
|
EdgeType (*func) (const NodeType &node_value,
|
||||||
|
const std::vector<EdgeType> &edge_in,
|
||||||
|
size_t out_index)
|
||||||
|
);
|
||||||
//---- recovery data structure ----
|
//---- recovery data structure ----
|
||||||
// call sequence counter, records how many calls we made so far
|
// call sequence counter, records how many calls we made so far
|
||||||
// from last call to CheckPoint, LoadCheckPoint
|
// from last call to CheckPoint, LoadCheckPoint
|
||||||
@ -254,4 +311,7 @@ class AllReduceRobust : public AllReduceBase {
|
|||||||
ResultBuffer resbuf;
|
ResultBuffer resbuf;
|
||||||
};
|
};
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
|
// implementation of inline template function
|
||||||
|
#include "./engine_robust-inl.h"
|
||||||
|
|
||||||
#endif // ALLREDUCE_ENGINE_ROBUST_H
|
#endif // ALLREDUCE_ENGINE_ROBUST_H
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user